Skip to content

Commit

Permalink
add function impl for pymilvus, specifically BM25 function
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian committed Sep 13, 2024
1 parent a7b14a2 commit 65c4fea
Show file tree
Hide file tree
Showing 13 changed files with 790 additions and 33 deletions.
217 changes: 217 additions & 0 deletions examples/hello_bm25.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# hello_bm25.py demonstrates how to insert only raw text data into Milvus and
# perform sparse vector based ANN search using the BM25 algorithm.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and filtering search on entities
# 6. delete entities by PK
# 7. upsert entities by PK
# 8. drop collection
import time

from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, Function, DataType, FunctionType,
Collection,
)

# Banner template for section headings (the text is padded to at least 30
# characters) and template for printing an elapsed time with 4 decimals.
fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"

#################################################################################
# 1. connect to Milvus
# Add a new connection alias `default` for Milvus server in `localhost:19530`
print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

# Report whether a collection named `hello_bm25` is already present; the result
# is only printed, it does not change what the script does next.
has = utility.has_collection("hello_bm25")
print(f"Does collection hello_bm25 exist in Milvus: {has}")

#################################################################################
# 2. create collection
# We're going to create a collection with 2 explicit fields and a function.
# +-+------------+------------+------------------+---------------------+
# | | field name | field type | other attributes | field description   |
# +-+------------+------------+------------------+---------------------+
# |1| "id"       | INT64      | is_primary=True  | primary field       |
# | |            |            | auto_id=True     |                     |
# +-+------------+------------+------------------+---------------------+
# |2| "document" | VarChar    | max_length=1000  | raw text document   |
# +-+------------+------------+------------------+---------------------+
#
# Function 'bm25' is used to convert raw text document to a sparse vector representation
# and store it in the 'sparse' field.
# +-+------------+---------------------+------------------+---------------------+
# | | field name | field type          | other attributes | field description   |
# +-+------------+---------------------+------------------+---------------------+
# |3| "sparse"   | SPARSE_FLOAT_VECTOR |                  | output of 'bm25'    |
# +-+------------+---------------------+------------------+---------------------+
#
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
FieldSchema(name="document", dtype=DataType.VARCHAR, max_length=1000),
]

# The function reads the `document` field and writes the `sparse` field.
bm25_function = Function(
name="bm25",
function_type=FunctionType.BM25,
inputs=["document"],
outputs=["sparse"],
params={"bm25_k1": 1.2, "bm25_b": 0.75},
)

# Attach the function to the schema before the collection is created from it.
schema = CollectionSchema(fields, "hello_bm25 demo")
schema.add_function(bm25_function)

print(fmt.format("Create collection `hello_bm25`"))
hello_bm25 = Collection("hello_bm25", schema, consistency_level="Strong")

################################################################################
# 3. insert data
# We are going to insert `num_entities` (6) rows of data into `hello_bm25`.
# Data to be inserted must be organized in fields. Only the `document` column
# is supplied here: `id` is declared with auto_id=True and `sparse` is the
# output field of the `bm25` function.
#
# The insert() method returns:
# - either automatically generated primary keys by Milvus if auto_id=True in the schema;
# - or the existing primary key field from the entities if auto_id=False in the schema.

print(fmt.format("Start inserting entities"))

num_entities = 6

entities = [
# The counter is offset by the collection's current row count, presumably so
# the texts differ from rows left over from an earlier run.
[f"This is a test document {i + hello_bm25.num_entities}" for i in range(num_entities)],
]

insert_result = hello_bm25.insert(entities)
# Keep the primary keys; later sections build their filter expressions from them.
ids = insert_result.primary_keys

# NOTE(review): the reason for pausing before flush() is not evident from this script.
time.sleep(3)

hello_bm25.flush()
print(f"Number of entities in Milvus: {hello_bm25.num_entities}") # check the num_entities

################################################################################
# 4. create index
# We are going to create a SPARSE_INVERTED_INDEX index on the `sparse` field
# (SPARSE_FLOAT_VECTOR) of the hello_bm25 collection, with BM25 as metric type.
print(fmt.format("Start Creating index SPARSE_INVERTED_INDEX"))
index = {
"index_type": "SPARSE_INVERTED_INDEX",
"metric_type": "BM25",
# Same bm25_k1 / bm25_b values as the `bm25` function declared in the schema.
'params': {"bm25_k1": 1.2, "bm25_b": 0.75},
}

hello_bm25.create_index("sparse", index)

################################################################################
# 5. search, query, and scalar filtering search
# After data were inserted into Milvus and indexed, you can perform:
# - search texts relevance by BM25 using sparse vector ANN search
# - query based on scalar filtering(boolean, int, etc.)
# - scalar filtering search.
#

# Before conducting a search or a query, you need to load the data in `hello_bm25` into memory.
print(fmt.format("Start loading"))
hello_bm25.load()

# -----------------------------------------------------------------------------
print(fmt.format("Start searching based on BM25 texts relevance using sparse vector ANN search"))
# Reuse the last two inserted documents as the raw-text search inputs.
texts_to_search = entities[-1][-2:]
print(fmt.format(f"texts_to_search: {texts_to_search}"))
search_params = {
    "metric_type": "BM25",
    "params": {},
}

# Raw texts are passed as the search data; `sparse` is the ANN search field.
start_time = time.time()
result = hello_bm25.search(
    texts_to_search,
    "sparse",
    search_params,
    limit=3,
    output_fields=["document"],
    consistency_level="Strong",
)
end_time = time.time()

# One group of hits is returned per searched text, in the same order.
for hits, text in zip(result, texts_to_search):
    print(f"result of text: {text}")
    for hit in hits:
        print(f"\thit: {hit}, document field: {hit.entity.get('document')}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# query based on scalar filtering(boolean, int, etc.)
# The threshold is the inserted primary key at index `num_entities // 2 - 1`.
filter_id = ids[num_entities // 2 - 1]
print(fmt.format(f"Start querying with `id > {filter_id}`"))

start_time = time.time()
result = hello_bm25.query(expr=f"id > {filter_id}", output_fields=["document"])
end_time = time.time()

# Only the first returned row is printed; this assumes the query matched at least one row.
print(f"query result:\n-{result[0]}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# pagination
# The same filter is queried twice: once with limit=3, and once with offset=1
# and limit=2, so the two printed results can be compared.
r1 = hello_bm25.query(expr=f"id > {filter_id}", limit=3, output_fields=["document"])
r2 = hello_bm25.query(expr=f"id > {filter_id}", offset=1, limit=2, output_fields=["document"])
print(f"query pagination(limit=3):\n\t{r1}")
print(f"query pagination(offset=1, limit=2):\n\t{r2}")


# -----------------------------------------------------------------------------
# scalar filtering search
# Same BM25 search as before, restricted to rows matching `id > filter_id`.
print(fmt.format(f"Start filtered searching with `id > {filter_id}`"))

start_time = time.time()
result = hello_bm25.search(
    texts_to_search,
    "sparse",
    search_params,
    limit=3,
    expr=f"id > {filter_id}",
    output_fields=["document"],
)
end_time = time.time()

# One group of hits is returned per searched text, in the same order.
for hits, text in zip(result, texts_to_search):
    print(f"result of text: {text}")
    for hit in hits:
        print(f"\thit: {hit}, document field: {hit.entity.get('document')}")
print(search_latency_fmt.format(end_time - start_time))

###############################################################################
# 6. delete entities by PK
# You can delete entities by their PK values using boolean expressions.

# Target the first two primary keys returned by the insert.
expr = f'id in [{ids[0]}, {ids[1]}]'
print(fmt.format(f"Start deleting with expr `{expr}`"))

# Show the two rows before deleting them; indexing result[0] and result[1]
# assumes both rows are returned.
result = hello_bm25.query(expr=expr, output_fields=["document"])
print(f"query before delete by expr=`{expr}` -> result: \n- {result[0]}\n- {result[1]}\n")

hello_bm25.delete(expr)

# Run the same query again and print whatever it returns after the delete.
result = hello_bm25.query(expr=expr, output_fields=["document"])
print(f"query after delete by expr=`{expr}` -> result: {result}\n")

###############################################################################
# 7. upsert by PK
# You can upsert data to replace existing data.
# The target is the third primary key returned by the insert (ids[0] and ids[1]
# were used by the delete step).
target_id = ids[2]
print(fmt.format(f"Start upsert operation for id {target_id}"))

# Query before upsert
result_before = hello_bm25.query(expr=f"id == {target_id}", output_fields=["id", "document"])
print(f"Query before upsert (id={target_id}):\n{result_before}")

# Prepare data for upsert
# Two columns are supplied: the primary key and the new document text.
# No `sparse` column is given, as in the insert step.
# NOTE(review): the schema declares auto_id=True for `id`; confirm that the
# server keeps the supplied primary key on upsert, otherwise the query by
# `target_id` below may not return the upserted row.
upsert_data = [
[target_id],
["This is an upserted document for testing purposes."]
]

# Perform upsert operation
hello_bm25.upsert(upsert_data)

# Query after upsert
result_after = hello_bm25.query(expr=f"id == {target_id}", output_fields=["id", "document"])
print(f"Query after upsert (id={target_id}):\n{result_after}")


###############################################################################
# 8. drop collection
# Finally, drop the hello_bm25 collection
print(fmt.format("Drop collection `hello_bm25`"))
utility.drop_collection("hello_bm25")
179 changes: 179 additions & 0 deletions examples/hello_hybrid_bm25.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# A demo showing hybrid semantic search with dense and full text search with BM25
# using Milvus.
#
# You can optionally choose to use the BGE-M3 model to embed the text as dense
# vectors, or simply use random generated vectors as an example.
#
# You can also use the BGE CrossEncoder model to rerank the search results.
#
# Note that the full text search feature is only available in Milvus 2.4.0 or
# higher version. Make sure you follow https://milvus.io/docs/install_standalone-docker.md
# to set up the latest version of Milvus in your local environment.

# To connect to Milvus server, you need the python client library called pymilvus.
# To use BGE-M3 model, you need to install the optional `model` module in pymilvus.
# You can get them by simply running the following commands:
#
# pip install pymilvus
# pip install pymilvus[model]

# If true, use BGE-M3 model to generate dense vectors.
# If false, use random numbers to compose dense vectors.
use_bge_m3 = False
# If true, the search result will be reranked using BGE CrossEncoder model.
use_reranker = False

# The overall steps are as follows:
# 1. embed the text as dense vectors (the sparse vectors are not computed in
#    this script; they are the output of the BM25 function in the schema)
# 2. setup a Milvus collection to store the dense and sparse vectors
# 3. insert the data to Milvus
# 4. search and inspect the result!
import random
import string
import numpy as np

from pymilvus import (
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
AnnSearchRequest,
RRFRanker,
connections,
Function,
FunctionType,
)

# 1. prepare a small corpus to search
def _random_word():
    """Return a lowercase ASCII word of 1 to 8 random letters."""
    length = random.randint(1, 8)
    return "".join(random.choice(string.ascii_lowercase) for _ in range(length))


def _random_text(word_count=10):
    """Return `word_count` random words joined by single spaces."""
    return " ".join(_random_word() for _ in range(word_count))


# Three meaningful sentences that the query below can actually match.
docs = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
    "Born in Maida Vale, London, Turing was raised in southern England.",
]
# Pad the corpus with 1000 randomly generated texts.
docs += [_random_text() for _ in range(1000)]

query = "Who started AI research?"


def random_embedding(texts, dim=768):
    """Return random dense vectors as a stand-in for real text embeddings.

    Args:
        texts: Sequence of input texts. Only its length is used.
        dim: Number of components per generated vector. Defaults to 768,
            the value that was previously hard-coded.

    Returns:
        A dict with the single key ``"dense"`` mapping to a float array of
        shape ``(len(texts), dim)`` with values drawn uniformly from [0, 1).
    """
    # Draw from the locally created Generator; previously it was created but
    # the legacy global `np.random.rand` was called instead.
    rng = np.random.default_rng()
    return {
        "dense": rng.random((len(texts), dim)),
    }


# Defaults: random 768-dimensional vectors produced by `random_embedding`.
dense_dim = 768
ef = random_embedding

if use_bge_m3:
    # BGE-M3 model is included in the optional `model` module in pymilvus, to
    # install it, simply run "pip install pymilvus[model]".
    # Imported lazily so the optional dependency is only needed when enabled.
    from pymilvus.model.hybrid import BGEM3EmbeddingFunction

    ef = BGEM3EmbeddingFunction(use_fp16=False, device="cpu")
    # Take the dense dimension from the model instead of the default above.
    dense_dim = ef.dim["dense"]

# Embed the corpus and the single query with whichever function was selected.
docs_embeddings = ef(docs)
query_embeddings = ef([query])

# 2. setup Milvus collection and index
connections.connect("default", host="localhost", port="19530")

# Specify the data schema for the new Collection.
fields = [
# Use auto generated id as primary key
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100),
# Store the original text so it can be returned with the search results
# (retrieval is based on semantic distance and BM25 relevance)
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512),
# We need a sparse vector field to perform full text search with BM25,
# but you don't need to provide data for it when inserting data.
FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=dense_dim),
]
# The BM25 function reads the `text` field and writes the `sparse_vector` field.
functions = [
Function(
name="bm25",
function_type=FunctionType.BM25,
inputs=["text"],
outputs=["sparse_vector"],
params={"bm25_k1": 1.2, "bm25_b": 0.75},
)
]
schema = CollectionSchema(fields, "", functions=functions)
col_name = "hybrid_bm25_demo"
# Now we can create the new collection with above name and schema.
col = Collection(col_name, schema, consistency_level="Strong")

# We need to create indices for the vector fields. The indices will be loaded
# into memory for efficient search.
# `sparse_vector` is the output of the BM25 function and is searched with
# metric type "BM25" later in this script, so its index uses the same metric.
sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25"}
col.create_index("sparse_vector", sparse_index)
# The dense vectors are indexed and searched with inner product (IP).
dense_index = {"index_type": "FLAT", "metric_type": "IP"}
col.create_index("dense_vector", dense_index)
col.load()

# 3. insert the texts and their dense vector representations into the
# collection. No sparse vectors are supplied: `sparse_vector` is the output
# field of the BM25 function, and `pk` is declared with auto_id=True.
entities = [docs, docs_embeddings["dense"]]
col.insert(entities)
col.flush()

# 4. search and inspect the result!
k = 2 # we want to get the top 2 docs closest to the query

# Prepare the search requests for both full text search and dense vector search
full_text_search_params = {"metric_type": "BM25"}
# Provide the raw text query for full text search, while using the sparse
# vector field as the ANNS field.
full_text_search_req = AnnSearchRequest([query], "sparse_vector", full_text_search_params, limit=k)
dense_search_params = {"metric_type": "IP"}
dense_req = AnnSearchRequest(
query_embeddings["dense"], "dense_vector", dense_search_params, limit=k
)

# Search topK docs based on dense and sparse vectors and rerank with RRF.
res = col.hybrid_search(
[full_text_search_req, dense_req], rerank=RRFRanker(), limit=k, output_fields=["text"]
)

# Currently Milvus only supports 1 query in the same hybrid search request, so
# we inspect res[0] directly. In a future release Milvus will accept batch
# hybrid search queries in the same call.
res = res[0]

if use_reranker:
    # Imported lazily so the optional `model` module is only needed when
    # reranking is enabled.
    from pymilvus.model.reranker import BGERerankFunction

    result_texts = [hit.fields["text"] for hit in res]
    bge_rf = BGERerankFunction(device="cpu")
    # rerank the results using BGE CrossEncoder model, keeping the same
    # number of results (`k`) that the hybrid search was limited to
    results = bge_rf(query, result_texts, top_k=k)
    for hit in results:
        print(f"text: {hit.text} distance {hit.score}")
else:
    # Without the reranker, print the hybrid search hits as returned.
    for hit in res:
        print(f'text: {hit.fields["text"]} distance {hit.distance}')

# If you used both BGE-M3 and the reranker, you should see the following:
# text: Alan Turing was the first person to conduct substantial research in AI. distance 0.9306981017573297
# text: Artificial intelligence was founded as an academic discipline in 1956. distance 0.03217001154515051
#
# If you used only BGE-M3, you should see the following:
# text: Alan Turing was the first person to conduct substantial research in AI. distance 0.032786883413791656
# text: Artificial intelligence was founded as an academic discipline in 1956. distance 0.016129031777381897

# In this simple example the reranker yields the same result as the embedding based hybrid search, but in more complex
# scenarios the reranker can provide more accurate results.

# If you used random vectors, the result will be different each time you run the script.

# Drop the collection created above (named by `col_name`) to clean up the data.
utility.drop_collection(col_name)
Loading

0 comments on commit 65c4fea

Please sign in to comment.