Skip to content

Commit

Permalink
Fix query to align with Qdrant mixin usage (#115)
Browse files Browse the repository at this point in the history
* fix: query in text_embedding_base to work with both Iterable and str as users might supply both

* Fix Qdrant query to align with future usage

* refactor(text_embedding_base.py): change query parameter type from str to Union[str, Iterable[str]] in query_embed method

* Update return type of query_embed method

* Update return type in TextEmbeddingBase
  • Loading branch information
NirantK authored Feb 7, 2024
1 parent 4696818 commit 973da35
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 31 deletions.
22 changes: 11 additions & 11 deletions docs/examples/Retrieval_with_FastEmbed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -58,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -105,7 +105,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -138,7 +138,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -148,7 +148,7 @@
"Rank 1: Maharana Pratap was a Rajput warrior king from Mewar\n",
"Rank 2: Maharana Pratap is considered a symbol of Rajput resistance against foreign rule\n",
"Rank 3: His legacy is celebrated in Rajasthan through festivals and monuments\n",
"Rank 4: His capital was Chittorgarh, which he lost to the Mughals\n",
"Rank 4: He had 11 wives and 17 sons, including Amar Singh I who succeeded him as ruler of Mewar\n",
"Rank 5: He fought against the Mughal Empire led by Akbar\n"
]
}
Expand All @@ -166,16 +166,16 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rank 1: He died in 1597 at the age of 57\n",
"Rank 2: His life has been depicted in various films, TV shows, and books\n",
"Rank 3: Maharana Pratap was a Rajput warrior king from Mewar\n",
"Rank 1: Maharana Pratap was a Rajput warrior king from Mewar\n",
"Rank 2: Maharana Pratap is considered a symbol of Rajput resistance against foreign rule\n",
"Rank 3: His legacy is celebrated in Rajasthan through festivals and monuments\n",
"Rank 4: He had 11 wives and 17 sons, including Amar Singh I who succeeded him as ruler of Mewar\n",
"Rank 5: He fought against the Mughal Empire led by Akbar\n"
]
Expand Down Expand Up @@ -213,7 +213,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.11.5"
},
"orig_nbformat": 4
},
Expand Down
33 changes: 20 additions & 13 deletions docs/examples/Usage_With_Qdrant.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,26 @@
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 77.7M/77.7M [00:05<00:00, 14.6MiB/s]\n"
]
},
{
"data": {
"text/plain": [
"['6e8fcf7e0ecc407b9b6bb011d169f629',\n",
" 'c9d26e7e0ea741b2b1082d097796b28b',\n",
" 'cf05747e7eb34d2490b1df1f8be94049',\n",
" '208c197266d547a880dfb65e46738b19',\n",
" '27bd985c5d6f49d68fc2cf73dac74199',\n",
" 'c5e929c8837f4370818c97f63996f8ef',\n",
" 'c12213c6cdac470aa2471f2d30dc4041',\n",
" '974e64a7d8624f6e9824fa7b9c94f99d',\n",
" '0129fae193c740eba092512d8e53ab4a',\n",
" '492cad6e741e4aeebb196bd818a97d17']"
"['4fa8b10c78da4b18ba0830ba8a57367a',\n",
" '2eae04b515ee4e9185a9a0e6be812bba',\n",
" 'c6039f88486f47f1835ae3b069c5823c',\n",
" 'c2c8c51e305144d1917b373125fb4d95',\n",
" '79fd23b9ec0648cdab38d1947c6b933e',\n",
" '036aa200d8c3492b8a438e4f825f5e7f',\n",
" 'c35c77f3ea37460a9a13723fb77b7367',\n",
" '6ebccbca571b40d0ab6e83e5e0f2f562',\n",
" '38048c2ccc1d4962a4f8f1bd89c8357a',\n",
" 'c6b09308360140c7b4f106af3658a31e']"
]
},
"execution_count": 4,
Expand Down Expand Up @@ -187,12 +194,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[QueryResponse(id='42', embedding=None, metadata={'document': 'Qdrant has Langchain integrations', 'source': 'Langchain-docs'}, document='Qdrant has Langchain integrations', score=0.8496814051311954), QueryResponse(id='2', embedding=None, metadata={'document': 'Qdrant also has Llama Index integrations', 'source': 'Linkedin-docs'}, document='Qdrant also has Llama Index integrations', score=0.8478494193031256)]\n"
"[QueryResponse(id=42, embedding=None, metadata={'document': 'Qdrant has Langchain integrations', 'source': 'Langchain-docs'}, document='Qdrant has Langchain integrations', score=0.8276550115796268), QueryResponse(id=2, embedding=None, metadata={'document': 'Qdrant also has Llama Index integrations', 'source': 'Linkedin-docs'}, document='Qdrant also has Llama Index integrations', score=0.8265536935180283)]\n"
]
}
],
"source": [
"search_result = client.query(collection_name=\"demo_collection\", query_text=[\"This is a query document\"])\n",
"search_result = client.query(collection_name=\"demo_collection\", query_text=\"This is a query document\")\n",
"print(search_result)"
]
},
Expand Down Expand Up @@ -226,7 +233,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.11.5"
},
"orig_nbformat": 4
},
Expand Down
16 changes: 9 additions & 7 deletions fastembed/text/text_embedding_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Iterable, List, Dict, Any
from typing import Any, Dict, Iterable, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -39,17 +39,19 @@ def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
# This is model-specific, so that different models can have specialized implementations
yield from self.embed(texts, **kwargs)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
    """
    Embeds queries

    Args:
        query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.
        **kwargs: Forwarded unchanged to ``self.embed``.

    Returns:
        Iterable[np.ndarray]: The embeddings, one per query, yielded lazily.
    """

    # This is model-specific, so that different models can have specialized implementations
    #
    # NOTE: a ``str`` is itself an ``Iterable`` (of characters), so the two
    # cases must be mutually exclusive. Checking them with two independent
    # ``if`` statements would embed a single string query once as a whole and
    # then a second time character by character.
    if isinstance(query, str):
        # Wrap the lone query so ``embed`` sees a one-element batch.
        yield from self.embed([query], **kwargs)
    else:
        yield from self.embed(query, **kwargs)

0 comments on commit 973da35

Please sign in to comment.