Skip to content

Commit

Permalink
Add support of insert by rows (#1434)
Browse files Browse the repository at this point in the history
Signed-off-by: zhenshan.cao <[email protected]>
  • Loading branch information
czs007 authored May 25, 2023
1 parent f9d5cdd commit 72761ed
Show file tree
Hide file tree
Showing 16 changed files with 788 additions and 350 deletions.
100 changes: 100 additions & 0 deletions examples/dynamic_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import time
import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)

# Banner format used for section headings, and the vector dimension used throughout.
fmt = "\n=== {:30} ===\n"
dim = 8

print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

# Drop any leftover collection from a previous run so the demo starts clean.
has = utility.has_collection("hello_milvus")
print(f"Does collection hello_milvus exist in Milvus: {has}")
if has:
    utility.drop_collection("hello_milvus")

# Declared schema: a string primary key, a scalar double, and a float vector field.
fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
    FieldSchema(name="random", dtype=DataType.DOUBLE),
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim)
]

# enable_dynamic_field=True lets inserted rows carry keys that are not declared
# above (see the "$meta" output field used in the search section below).
schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs", enable_dynamic_field=True)

print(fmt.format("Create collection `hello_milvus`"))
hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong")

################################################################################
# 3. insert data
# Opening an existing collection by name only (no schema argument) — shown for
# demonstration; the rest of the script keeps using `hello_milvus`.
hello_milvus2 = Collection("hello_milvus")
print(fmt.format("Start inserting entities"))
rng = np.random.default_rng(seed=19530)

# Row-based insert: each dict is one entity. The keys "a".."f" are not part of
# the declared schema; they are accepted because the collection was created
# with enable_dynamic_field=True.
rows = [
    {"pk": "1", "random": 1.0, "embeddings": rng.random((1, dim))[0], "a": 1},
    {"pk": "2", "random": 1.0, "embeddings": rng.random((1, dim))[0], "b": 1},
    {"pk": "3", "random": 1.0, "embeddings": rng.random((1, dim))[0], "c": 1},
    {"pk": "4", "random": 1.0, "embeddings": rng.random((1, dim))[0], "d": 1},
    {"pk": "5", "random": 1.0, "embeddings": rng.random((1, dim))[0], "e": 1},
    {"pk": "6", "random": 1.0, "embeddings": rng.random((1, dim))[0], "f": 1},
]

insert_result = hello_milvus.insert(rows)

# A single dict (one row) can be inserted directly as well.
hello_milvus.insert({"pk": "7", "random": 1.0, "embeddings": rng.random((1, dim))[0], "g": 1})
hello_milvus.flush()
print(f"Number of entities in Milvus: {hello_milvus.num_entities}")  # check the num_entities

# 4. create index
# IVF_FLAT with L2 distance; nlist is the number of cluster buckets.
print(fmt.format("Start Creating index IVF_FLAT"))
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}

hello_milvus.create_index("embeddings", index)

# The collection must be loaded into memory before it can be searched.
print(fmt.format("Start loading"))
hello_milvus.load()
# -----------------------------------------------------------------------------
# search based on vector similarity
print(fmt.format("Start searching based on vector similarity"))

# Re-seed with the same value so the query vector is reproducible across runs.
rng = np.random.default_rng(seed=19530)
vectors_to_search = rng.random((1, dim))
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": 10},
}

start_time = time.time()
result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "embeddings"])
end_time = time.time()

for hits in result:
    for hit in hits:
        print(f"hit: {hit}")
# Fix: start_time/end_time were computed but never reported — print the latency.
print(f"search latency = {end_time - start_time:.4f}s")


# Requesting the special "$meta" output field also returns the dynamic
# (undeclared) keys stored with each entity.
result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["pk", "embeddings", "$meta"])
for hits in result:
    for hit in hits:
        print(f"hit: {hit}")

# Filter expression mixing a declared field (pk) with a dynamic field (g).
# Fix: plain string literal — the original used an f-string with no
# placeholders (lint F541), which is misleading and serves no purpose.
expr = 'pk in ["1" , "2"] || g == 1'

print(fmt.format(f"Start query with expr `{expr}`"))
result = hello_milvus.query(expr=expr, output_fields=["random", "a", "g"])
for hit in result:
    print("hit:", hit)

###############################################################################
# 7. drop collection
print(fmt.format("Drop collection `hello_milvus`"))
utility.drop_collection("hello_milvus")
73 changes: 28 additions & 45 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import abc

import numpy as np
from ..settings import Config
from .types import DataType
from .constants import DEFAULT_CONSISTENCY_LEVEL
from ..grpc_gen import schema_pb2
from ..exceptions import MilvusException
from . import entity_helper


class LoopBase:
Expand Down Expand Up @@ -70,6 +70,7 @@ def __init__(self, raw):
self.params = {}
self.is_partition_key = False
self.default_value = None
self.is_dynamic = False

##
self.__pack(self._raw)
Expand All @@ -83,6 +84,11 @@ def __pack(self, raw):
self.type = raw.data_type
self.is_partition_key = raw.is_partition_key
self.default_value = raw.default_value
try:
self.is_dynamic = raw.is_dynamic
except Exception:
self.is_dynamic = False

# self.type = DataType(int(raw.type))

for type_param in raw.type_params:
Expand Down Expand Up @@ -115,7 +121,8 @@ def dict(self):
"is_primary": self.is_primary,
"auto_id": self.auto_id,
"is_partition_key": self.is_partition_key,
"default_value": self.default_value
"default_value": self.default_value,
"is_dynamic": self.is_dynamic,
}
return _dict

Expand All @@ -137,6 +144,7 @@ def __init__(self, raw):
self.properties = {}
self.num_shards = 0
self.num_partitions = 0
self.enable_dynamic_field = False

#
if self._raw:
Expand All @@ -156,6 +164,11 @@ def __pack(self, raw):
except Exception:
self.consistency_level = DEFAULT_CONSISTENCY_LEVEL

try:
self.enable_dynamic_field = raw.schema.enable_dynamic_field
except Exception:
self.enable_dynamic_field = False

# self.params = dict()
# TODO: extra_params here
# for kv in raw.extra_params:
Expand Down Expand Up @@ -184,6 +197,7 @@ def dict(self):
"consistency_level": self.consistency_level,
"properties": self.properties,
"num_partitions": self.num_partitions,
"enable_dynamic_field": self.enable_dynamic_field,
}
return _dict

Expand Down Expand Up @@ -264,6 +278,11 @@ def __init__(self, raw, auto_id, round_decimal=-1):
else:
self._distances = self._raw.scores

self._dynamic_field_name = None
self._dynamic_fields = set()
self._dynamic_field_name, self._dynamic_fields = entity_helper.extract_dynamic_field_from_result(self._raw)


def __len__(self):
if self._raw.ids.HasField("int_id"):
return len(self._raw.ids.int_id.data)
Expand All @@ -278,48 +297,9 @@ def get__item(self, item):
entity_id = self._raw.ids.str_id.data[item]
else:
raise MilvusException(message="Unsupported ids type")
entity_row_data = {}
if self._raw.fields_data:
for field_data in self._raw.fields_data:
if field_data.type == DataType.BOOL:
if len(field_data.scalars.bool_data.data) >= item:
entity_row_data[field_data.field_name] = field_data.scalars.bool_data.data[item]
elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32):
if len(field_data.scalars.int_data.data) >= item:
entity_row_data[field_data.field_name] = field_data.scalars.int_data.data[item]
elif field_data.type == DataType.INT64:
if len(field_data.scalars.long_data.data) >= item:
entity_row_data[field_data.field_name] = field_data.scalars.long_data.data[item]
elif field_data.type == DataType.FLOAT:
if len(field_data.scalars.float_data.data) >= item:
entity_row_data[field_data.field_name] = np.single(field_data.scalars.float_data.data[item])
elif field_data.type == DataType.DOUBLE:
if len(field_data.scalars.double_data.data) >= item:
entity_row_data[field_data.field_name] = field_data.scalars.double_data.data[item]
elif field_data.type == DataType.VARCHAR:
if len(field_data.scalars.string_data.data) >= item:
entity_row_data[field_data.field_name] = field_data.scalars.string_data.data[item]
elif field_data.type == DataType.STRING:
raise MilvusException(message="Not support string yet")
# result[field_data.field_name] = field_data.scalars.string_data.data[index]
elif field_data.type == DataType.JSON:
if len(field_data.scalars.json_data.data) >= item:
entity_row_data[field_data.field_name] = field_data.scalars.json_data.data[item]
elif field_data.type == DataType.FLOAT_VECTOR:
dim = field_data.vectors.dim
if len(field_data.vectors.float_vector.data) >= item * dim:
start_pos = item * dim
end_pos = item * dim + dim
entity_row_data[field_data.field_name] = [np.single(x) for x in
field_data.vectors.float_vector.data[
start_pos:end_pos]]
elif field_data.type == DataType.BINARY_VECTOR:
dim = field_data.vectors.dim
if len(field_data.vectors.binary_vector.data) >= item * (dim / 8):
start_pos = item * (dim / 8)
end_pos = (item + 1) * (dim / 8)
entity_row_data[field_data.field_name] = [
field_data.vectors.binary_vector.data[start_pos:end_pos]]

entity_row_data = entity_helper.extract_row_data_from_fields_data(self._raw.fields_data, item,
self._dynamic_fields)
entity_score = self._distances[item]
return Hit(entity_id, entity_row_data, entity_score)

Expand Down Expand Up @@ -387,7 +367,7 @@ def err_index(self):

def __str__(self):
return f"(insert count: {self._insert_cnt}, delete count: {self._delete_cnt}, upsert count: {self._upsert_cnt}, " \
f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count})"
f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count})"

__repr__ = __str__

Expand Down Expand Up @@ -497,6 +477,7 @@ def _pack(self, raw_list):
self._nq += nq
self._topk = raw.results.top_k
offset = 0

for i in range(nq):
hit = schema_pb2.SearchResultData()
start_pos = offset
Expand All @@ -506,10 +487,12 @@ def _pack(self, raw_list):
hit.ids.int_id.data.extend(raw.results.ids.int_id.data[start_pos: end_pos])
elif raw.results.ids.HasField("str_id"):
hit.ids.str_id.data.extend(raw.results.ids.str_id.data[start_pos: end_pos])
hit.output_fields.extend(raw.results.output_fields)
for field_data in raw.results.fields_data:
field = schema_pb2.FieldData()
field.type = field_data.type
field.field_name = field_data.field_name
field.is_dynamic = field_data.is_dynamic
if field_data.type == DataType.BOOL:
field.scalars.bool_data.data.extend(field_data.scalars.bool_data.data[start_pos: end_pos])
elif field_data.type in (DataType.INT8, DataType.INT16, DataType.INT32):
Expand Down
Loading

0 comments on commit 72761ed

Please sign in to comment.