-
Notifications
You must be signed in to change notification settings - Fork 335
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add function impl for pymilvus, specifically BM25 function
Signed-off-by: Buqian Zheng <[email protected]>
- Loading branch information
1 parent
b51ebce
commit 2cecf9d
Showing
12 changed files
with
843 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
# hello_bm25.py demonstrates how to insert raw data only into Milvus and perform | ||
# sparse vector based ANN search using BM25 algorithm. | ||
# 1. connect to Milvus | ||
# 2. create collection | ||
# 3. insert data | ||
# 4. create index | ||
# 5. search, query, and filtering search on entities | ||
# 6. delete entities by PK | ||
# 7. upsert entities by PK | ||
# 8. drop collection | ||
import time | ||
|
||
from pymilvus import ( | ||
connections, | ||
utility, | ||
FieldSchema, CollectionSchema, Function, DataType, FunctionType, | ||
Collection, | ||
) | ||
|
||
fmt = "\n=== {:30} ===\n" | ||
search_latency_fmt = "search latency = {:.4f}s" | ||
|
||
################################################################################# | ||
# 1. connect to Milvus | ||
# Add a new connection alias `default` for Milvus server in `localhost:19530` | ||
print(fmt.format("start connecting to Milvus")) | ||
connections.connect("default", host="localhost", port="19530") | ||
|
||
has = utility.has_collection("hello_bm25") | ||
print(f"Does collection hello_bm25 exist in Milvus: {has}") | ||
|
||
################################################################################# | ||
# 2. create collection | ||
# We're going to create a collection with 2 explicit fields and a function. | ||
# +-+------------+------------+------------------+------------------------------+ | ||
# | | field name | field type | other attributes | field description | | ||
# +-+------------+------------+------------------+------------------------------+ | ||
# |1| "id" | INT64 | is_primary=True | "primary field" | | ||
# | | | | auto_id=True | | | ||
# +-+------------+------------+------------------+------------------------------+ | ||
# |2| "document" | VarChar | | "raw text document" | | ||
# +-+------------+------------+------------------+------------------------------+ | ||
# | ||
# Function 'bm25' is used to convert raw text document to a sparse vector representation | ||
# and store it in the 'sparse' field. | ||
# +-+------------+-------------------+-----------+------------------------------+ | ||
# | | field name | field type | other attr| field description | | ||
# +-+------------+-------------------+-----------+------------------------------+ | ||
# |3| "sparse" |SPARSE_FLOAT_VECTOR| | | | ||
# +-+------------+-------------------+-----------+------------------------------+ | ||
# | ||
fields = [ | ||
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), | ||
FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR), | ||
FieldSchema(name="document", dtype=DataType.VARCHAR, max_length=1000, enable_tokenizer=True), | ||
] | ||
|
||
bm25_function = Function( | ||
name="bm25", | ||
function_type=FunctionType.BM25, | ||
input_field_names=["document"], | ||
output_field_names="sparse", | ||
) | ||
|
||
schema = CollectionSchema(fields, "hello_bm25 demo") | ||
schema.add_function(bm25_function) | ||
|
||
print(fmt.format("Create collection `hello_bm25`")) | ||
hello_bm25 = Collection("hello_bm25", schema, consistency_level="Strong") | ||
|
||
################################################################################ | ||
# 3. insert data | ||
# We are going to insert 6 rows of data into `hello_bm25` | ||
# Data to be inserted must be organized in fields. | ||
# | ||
# The insert() method returns: | ||
# - either automatically generated primary keys by Milvus if auto_id=True in the schema; | ||
# - or the existing primary key field from the entities if auto_id=False in the schema. | ||
|
||
print(fmt.format("Start inserting entities")) | ||
|
||
num_entities = 6 | ||
|
||
entities = [ | ||
[f"This is a test document {i + hello_bm25.num_entities}" for i in range(num_entities)], | ||
] | ||
|
||
insert_result = hello_bm25.insert(entities) | ||
ids = insert_result.primary_keys | ||
|
||
time.sleep(3) | ||
|
||
hello_bm25.flush() | ||
print(f"Number of entities in Milvus: {hello_bm25.num_entities}") # check the num_entities | ||
|
||
################################################################################ | ||
# 4. create index | ||
# We are going to create an index for the hello_bm25 collection; here we simply | ||
# use AUTOINDEX so Milvus can pick the default parameters. | ||
print(fmt.format("Start Creating index AUTOINDEX")) | ||
index = { | ||
"index_type": "AUTOINDEX", | ||
"metric_type": "BM25", | ||
} | ||
|
||
hello_bm25.create_index("sparse", index) | ||
|
||
################################################################################ | ||
# 5. search, query, and scalar filtering search | ||
# After data were inserted into Milvus and indexed, you can perform: | ||
# - search texts relevance by BM25 using sparse vector ANN search | ||
# - query based on scalar filtering(boolean, int, etc.) | ||
# - scalar filtering search. | ||
# | ||
|
||
# Before conducting a search or a query, you need to load the data in `hello_bm25` into memory. | ||
print(fmt.format("Start loading")) | ||
hello_bm25.load() | ||
|
||
# ----------------------------------------------------------------------------- | ||
print(fmt.format("Start searching based on BM25 texts relevance using sparse vector ANN search")) | ||
texts_to_search = entities[-1][-2:] | ||
print(fmt.format(f"texts_to_search: {texts_to_search}")) | ||
search_params = { | ||
"metric_type": "BM25", | ||
"params": {}, | ||
} | ||
|
||
start_time = time.time() | ||
result = hello_bm25.search(texts_to_search, "sparse", search_params, limit=3, output_fields=["document"], consistency_level="Strong") | ||
end_time = time.time() | ||
|
||
for hits, text in zip(result, texts_to_search): | ||
print(f"result of text: {text}") | ||
for hit in hits: | ||
print(f"\thit: {hit}, document field: {hit.entity.get('document')}") | ||
print(search_latency_fmt.format(end_time - start_time)) | ||
|
||
# ----------------------------------------------------------------------------- | ||
# query based on scalar filtering(boolean, int, etc.) | ||
filter_id = ids[num_entities // 2 - 1] | ||
print(fmt.format(f"Start querying with `id > {filter_id}`")) | ||
|
||
start_time = time.time() | ||
result = hello_bm25.query(expr=f"id > {filter_id}", output_fields=["document"]) | ||
end_time = time.time() | ||
|
||
print(f"query result:\n-{result[0]}") | ||
print(search_latency_fmt.format(end_time - start_time)) | ||
|
||
# ----------------------------------------------------------------------------- | ||
# pagination | ||
r1 = hello_bm25.query(expr=f"id > {filter_id}", limit=3, output_fields=["document"]) | ||
r2 = hello_bm25.query(expr=f"id > {filter_id}", offset=1, limit=2, output_fields=["document"]) | ||
print(f"query pagination(limit=3):\n\t{r1}") | ||
print(f"query pagination(offset=1, limit=2):\n\t{r2}") | ||
|
||
|
||
# ----------------------------------------------------------------------------- | ||
# scalar filtering search | ||
print(fmt.format(f"Start filtered searching with `id > {filter_id}`")) | ||
|
||
start_time = time.time() | ||
result = hello_bm25.search(texts_to_search, "sparse", search_params, limit=3, expr=f"id > {filter_id}", output_fields=["document"]) | ||
end_time = time.time() | ||
|
||
for hits, text in zip(result, texts_to_search): | ||
print(f"result of text: {text}") | ||
for hit in hits: | ||
print(f"\thit: {hit}, document field: {hit.entity.get('document')}") | ||
print(search_latency_fmt.format(end_time - start_time)) | ||
|
||
############################################################################### | ||
# 6. delete entities by PK | ||
# You can delete entities by their PK values using boolean expressions. | ||
|
||
expr = f'id in [{ids[0]}, {ids[1]}]' | ||
print(fmt.format(f"Start deleting with expr `{expr}`")) | ||
|
||
result = hello_bm25.query(expr=expr, output_fields=["document"]) | ||
print(f"query before delete by expr=`{expr}` -> result: \n- {result[0]}\n- {result[1]}\n") | ||
|
||
hello_bm25.delete(expr) | ||
|
||
result = hello_bm25.query(expr=expr, output_fields=["document"]) | ||
print(f"query after delete by expr=`{expr}` -> result: {result}\n") | ||
|
||
############################################################################### | ||
# 7. upsert by PK | ||
# You can upsert data to replace existing data. | ||
target_id = ids[2] | ||
print(fmt.format(f"Start upsert operation for id {target_id}")) | ||
|
||
# Query before upsert | ||
result_before = hello_bm25.query(expr=f"id == {target_id}", output_fields=["id", "document"]) | ||
print(f"Query before upsert (id={target_id}):\n{result_before}") | ||
|
||
# Prepare data for upsert | ||
upsert_data = [ | ||
[target_id], | ||
["This is an upserted document for testing purposes."] | ||
] | ||
|
||
# Perform upsert operation | ||
hello_bm25.upsert(upsert_data) | ||
|
||
# Query after upsert | ||
result_after = hello_bm25.query(expr=f"id == {target_id}", output_fields=["id", "document"]) | ||
print(f"Query after upsert (id={target_id}):\n{result_after}") | ||
|
||
|
||
############################################################################### | ||
# 8. drop collection | ||
# Finally, drop the hello_bm25 collection | ||
print(fmt.format("Drop collection `hello_bm25`")) | ||
utility.drop_collection("hello_bm25") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# A demo showing hybrid semantic search with dense and full text search with BM25 | ||
# using Milvus. | ||
# | ||
# You can optionally choose to use the BGE-M3 model to embed the text as dense | ||
# vectors, or simply use random generated vectors as an example. | ||
# | ||
# You can also use the BGE CrossEncoder model to rerank the search results. | ||
# | ||
# Note that the full text search feature is only available in Milvus 2.5.0 or | ||
# higher version. Make sure you follow https://milvus.io/docs/install_standalone-docker.md | ||
# to set up the latest version of Milvus in your local environment. | ||
|
||
# To connect to Milvus server, you need the python client library called pymilvus. | ||
# To use BGE-M3 model, you need to install the optional `model` module in pymilvus. | ||
# You can get them by simply running the following commands: | ||
# | ||
# pip install pymilvus | ||
# pip install pymilvus[model] | ||
|
||
# If true, use BGE-M3 model to generate dense vectors. | ||
# If false, use random numbers to compose dense vectors. | ||
use_bge_m3 = False | ||
# If true, the search result will be reranked using BGE CrossEncoder model. | ||
use_reranker = False | ||
|
||
# The overall steps are as follows: | ||
# 1. embed the text as dense vectors (sparse vectors come from the BM25 function) | ||
# 2. setup a Milvus collection to store the dense and sparse vectors | ||
# 3. insert the data to Milvus | ||
# 4. search and inspect the result! | ||
import random | ||
import string | ||
import numpy as np | ||
|
||
from pymilvus import ( | ||
utility, | ||
FieldSchema, | ||
CollectionSchema, | ||
DataType, | ||
Collection, | ||
AnnSearchRequest, | ||
RRFRanker, | ||
connections, | ||
Function, | ||
FunctionType, | ||
) | ||
|
||
# 1. prepare a small corpus to search | ||
docs = [ | ||
"Artificial intelligence was founded as an academic discipline in 1956.", | ||
"Alan Turing was the first person to conduct substantial research in AI.", | ||
"Born in Maida Vale, London, Turing was raised in southern England.", | ||
] | ||
# add some randomly generated texts | ||
docs.extend( | ||
[ | ||
" ".join( | ||
"".join(random.choice(string.ascii_lowercase) for _ in range(random.randint(1, 8))) | ||
for _ in range(10) | ||
) | ||
for _ in range(1000) | ||
] | ||
) | ||
query = "Who started AI research?" | ||
|
||
|
||
def random_embedding(texts, dim=768):
    """Return random dense vectors standing in for real text embeddings.

    Args:
        texts: Sequence of input texts; only its length is used.
        dim: Dimensionality of each generated vector. Defaults to 768.

    Returns:
        A dict with a single key ``"dense"`` holding a ``(len(texts), dim)``
        float array of values drawn uniformly from ``[0, 1)``.
    """
    # Draw from a local Generator rather than the legacy global
    # ``np.random`` state, so this helper does not touch shared RNG state.
    rng = np.random.default_rng()
    return {
        "dense": rng.random((len(texts), dim)),
    }
|
||
|
||
dense_dim = 768 | ||
ef = random_embedding | ||
|
||
if use_bge_m3: | ||
# BGE-M3 model is included in the optional `model` module in pymilvus, to | ||
# install it, simply run "pip install pymilvus[model]". | ||
from pymilvus.model.hybrid import BGEM3EmbeddingFunction | ||
|
||
ef = BGEM3EmbeddingFunction(use_fp16=False, device="cpu") | ||
dense_dim = ef.dim["dense"] | ||
|
||
docs_embeddings = ef(docs) | ||
query_embeddings = ef([query]) | ||
|
||
# 2. setup Milvus collection and index | ||
connections.connect("default", host="localhost", port="19530") | ||
|
||
# Specify the data schema for the new Collection. | ||
fields = [ | ||
# Use auto generated id as primary key | ||
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100), | ||
# Store the original text to retrieve based on semantic distance | ||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512, enable_tokenizer=True), | ||
# We need a sparse vector field to perform full text search with BM25, | ||
# but you don't need to provide data for it when inserting data. | ||
FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR), | ||
FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=dense_dim), | ||
] | ||
functions = [ | ||
Function( | ||
name="bm25", | ||
function_type=FunctionType.BM25, | ||
input_field_names=["text"], | ||
output_field_names="sparse_vector", | ||
) | ||
] | ||
schema = CollectionSchema(fields, "", functions=functions) | ||
col_name = "hybrid_bm25_demo" | ||
# Now we can create the new collection with above name and schema. | ||
col = Collection(col_name, schema, consistency_level="Strong") | ||
|
||
# We need to create indices for the vector fields. The indices will be loaded
# into memory for efficient search.
# The sparse field is populated by the BM25 function and is searched with
# metric_type "BM25", so its index must be built with the same metric.
sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25"}
col.create_index("sparse_vector", sparse_index)
# The dense field is searched with inner product, matching the search params.
dense_index = {"index_type": "FLAT", "metric_type": "IP"}
col.create_index("dense_vector", dense_index)
# Load the collection so it can be searched.
col.load()
|
||
# 3. insert text and dense vector representations into the collection; no data is provided for the sparse vector field | ||
entities = [docs, docs_embeddings["dense"]] | ||
col.insert(entities) | ||
col.flush() | ||
|
||
# 4. search and inspect the result! | ||
k = 2 # we want to get the top 2 docs closest to the query | ||
|
||
# Prepare the search requests for both full text search and dense vector search | ||
full_text_search_params = {"metric_type": "BM25"} | ||
# provide raw text query for full text search, while use the sparse vector as | ||
# ANNS field | ||
full_text_search_req = AnnSearchRequest([query], "sparse_vector", full_text_search_params, limit=k) | ||
dense_search_params = {"metric_type": "IP"} | ||
dense_req = AnnSearchRequest( | ||
query_embeddings["dense"], "dense_vector", dense_search_params, limit=k | ||
) | ||
|
||
# Search topK docs based on dense and sparse vectors and rerank with RRF. | ||
res = col.hybrid_search( | ||
[full_text_search_req, dense_req], rerank=RRFRanker(), limit=k, output_fields=["text"] | ||
) | ||
|
||
# Currently Milvus only supports 1 query in the same hybrid search request, so | ||
# we inspect res[0] directly. In a future release Milvus will accept batch | ||
# hybrid search queries in the same call. | ||
res = res[0] | ||
|
||
if use_reranker: | ||
result_texts = [hit.fields["text"] for hit in res] | ||
from pymilvus.model.reranker import BGERerankFunction | ||
|
||
bge_rf = BGERerankFunction(device="cpu") | ||
# rerank the results using BGE CrossEncoder model | ||
results = bge_rf(query, result_texts, top_k=2) | ||
for hit in results: | ||
print(f"text: {hit.text} distance {hit.score}") | ||
else: | ||
for hit in res: | ||
print(f'text: {hit.fields["text"]} distance {hit.distance}') | ||
|
||
# If you used both BGE-M3 and the reranker, you should see the following: | ||
# text: Alan Turing was the first person to conduct substantial research in AI. distance 0.9306981017573297 | ||
# text: Artificial intelligence was founded as an academic discipline in 1956. distance 0.03217001154515051 | ||
# | ||
# If you used only BGE-M3, you should see the following: | ||
# text: Alan Turing was the first person to conduct substantial research in AI. distance 0.032786883413791656 | ||
# text: Artificial intelligence was founded as an academic discipline in 1956. distance 0.016129031777381897 | ||
|
||
# In this simple example the reranker yields the same result as the embedding based hybrid search, but in more complex | ||
# scenarios the reranker can provide more accurate results. | ||
|
||
# If you used random vectors, the result will be different each time you run the script. | ||
|
||
# Drop the collection to clean up the data. | ||
utility.drop_collection(col_name) |
Oops, something went wrong.