Skip to content

Commit

Permalink
Support upsert (#1303)
Browse files Browse the repository at this point in the history
Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
  • Loading branch information
smellthemoon and lixinguo authored Feb 20, 2023
1 parent 061c20e commit 4dd6792
Show file tree
Hide file tree
Showing 10 changed files with 244 additions and 78 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/collection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ Methods
+---------------------------------------------------------------+--------------------------------------------------------------------------+
| `search() <#pymilvus.Collection.search>`_ | Vector similarity search with an optional boolean expression as filters. |
+---------------------------------------------------------------+--------------------------------------------------------------------------+
| `upsert() <#pymilvus.Collection.upsert>`_ | Upsert data of collection. |
+---------------------------------------------------------------+--------------------------------------------------------------------------+
| `query() <#pymilvus.Collection.query>`_ | Query with a set of criteria. |
+---------------------------------------------------------------+--------------------------------------------------------------------------+
| `partition() <#pymilvus.Collection.partition>`_ | Return the partition corresponding to name. |
Expand Down
2 changes: 2 additions & 0 deletions docs/source/api/partition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ Methods
+--------------------------------------------+--------------------------------------------------------------------------+
| `delete() <#pymilvus.Partition.delete>`_ | Delete entities with an expression condition. |
+--------------------------------------------+--------------------------------------------------------------------------+
| `upsert() <#pymilvus.Partition.upsert>`_   | Upsert data into the partition.                                          |
+--------------------------------------------+--------------------------------------------------------------------------+
| `search() <#pymilvus.Partition.search>`_ | Vector similarity search with an optional boolean expression as filters. |
+--------------------------------------------+--------------------------------------------------------------------------+
| `query() <#pymilvus.Partition.query>`_ | Query with a set of criteria. |
Expand Down
58 changes: 48 additions & 10 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,22 +353,31 @@ def get_partition_stats(self, collection_name, partition_name, timeout=None, **k

raise MilvusException(status.error_code, status.reason)

def _prepare_batch_insert_or_upsert_request(self, collection_name, entities, partition_name=None, timeout=None, isInsert=True, **kwargs):
    """Build the gRPC request for a batch insert (``isInsert=True``) or upsert.

    :param collection_name: name of the target collection.
    :param entities: list of per-field entity dicts to send.
    :param partition_name: optional partition to write into.
    :param timeout: RPC timeout (seconds) used if the schema must be fetched.
    :param isInsert: True -> InsertRequest, False -> UpsertRequest.
    :param kwargs: may carry a pre-built request under ``insert_param`` /
        ``upsert_param``, and/or a cached ``schema`` to skip the
        describe_collection round trip.
    :raises ParamError: if the pre-built request has the wrong type or
        ``entities`` is not a list.
    :return: the request object ready to be sent over the stub.
    """
    # The caller may hand us a pre-built request; which kwarg holds it
    # depends on the operation.
    param_key = 'insert_param' if isInsert else 'upsert_param'
    param = kwargs.get(param_key, None)

    if param and not isinstance(param, milvus_types.RowBatch):
        raise ParamError(
            message=f"The value of key '{param_key}' is invalid")

    if not isinstance(entities, list):
        raise ParamError(
            message="None entities, please provide valid entities.")

    collection_schema = kwargs.get("schema", None)
    if not collection_schema:
        # No cached schema supplied -> fetch it from the server.
        collection_schema = self.describe_collection(
            collection_name, timeout=timeout, **kwargs)

    fields_info = collection_schema["fields"]

    return param if param \
        else Prepare.batch_insert_or_upsert_param(collection_name, entities, partition_name, fields_info, isInsert)

Expand All @@ -378,7 +387,8 @@ def batch_insert(self, collection_name, entities, partition_name=None, timeout=N
raise ParamError(message="Invalid binary vector data exists")

try:
request = self._prepare_batch_insert_request(collection_name, entities, partition_name, timeout, **kwargs)
request = self._prepare_batch_insert_or_upsert_request(
collection_name, entities, partition_name, timeout, **kwargs)
rf = self._stub.Insert.future(request, timeout=timeout)
if kwargs.get("_async", False) is True:
cb = kwargs.get("_callback", None)
Expand Down Expand Up @@ -423,6 +433,34 @@ def delete(self, collection_name, expression, partition_name=None, timeout=None,
return MutationFuture(None, None, err)
raise err

@retry_on_rpc_failure()
def upsert(self, collection_name, entities, partition_name=None, timeout=None, **kwargs):
    """Upsert `entities` into `collection_name` via the Upsert RPC.

    Mirrors batch_insert: validates binary vectors, builds the request
    (with isInsert=False so an UpsertRequest is produced), and supports
    the same async mode via the ``_async`` / ``_callback`` kwargs.

    :raises ParamError: if `entities` contains invalid binary vector data.
    :raises MilvusException: if the server responds with a non-zero status
        (synchronous path only).
    :return: MutationResult on success, or a MutationFuture when
        ``_async=True`` (the future also wraps any construction error).
    """
    if not check_invalid_binary_vector(entities):
        raise ParamError(message="Invalid binary vector data exists")

    try:
        # isInsert=False -> build an UpsertRequest instead of an InsertRequest.
        request = self._prepare_batch_insert_or_upsert_request(
            collection_name, entities, partition_name, timeout, False, **kwargs)
        rf = self._stub.Upsert.future(request, timeout=timeout)
        if kwargs.get("_async", False) is True:
            # Async path: hand back a future; the timestamp bookkeeping runs
            # in a callback once the RPC completes.
            cb = kwargs.get("_callback", None)
            f = MutationFuture(rf, cb, timeout=timeout, **kwargs)
            f.add_callback(ts_utils.update_ts_on_mutation(collection_name))
            return f

        response = rf.result()
        if response.status.error_code == 0:
            m = MutationResult(response)
            # Record the mutation timestamp so later reads with bounded
            # consistency can see this write.
            ts_utils.update_collection_ts(collection_name, m.timestamp)
            return m

        raise MilvusException(
            response.status.error_code, response.status.reason)
    except Exception as err:
        # In async mode errors are delivered through the returned future
        # rather than raised here.
        if kwargs.get("_async", False):
            return MutationFuture(None, None, err)
        raise err

def _execute_search_requests(self, requests, timeout=None, **kwargs):
auto_id = kwargs.get("auto_id", True)

Expand Down
62 changes: 14 additions & 48 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import entity_helper
from .check import check_pass_param, is_legal_collection_properties
from .types import DataType, PlaceholderType, get_consistency_level
from .utils import traverse_info
from .constants import DEFAULT_CONSISTENCY_LEVEL
from ..exceptions import ParamError, DataNotMatchException, ExceptionsMessage
from ..orm.schema import CollectionSchema
Expand Down Expand Up @@ -247,58 +248,23 @@ def partition_name(cls, collection_name, partition_name):
tag=partition_name)

@classmethod
def batch_insert_param(cls, collection_name, entities, partition_name, fields_info=None, **kwargs):
# insert_request.hash_keys won't be filled in client. It will be filled in proxy.
def batch_insert_or_upsert_param(cls, collection_name, entities, partition_name, fields_info=None, isInsert=True, **kwargs):
# insert_request.hash_keys and upsert_request.hash_keys won't be filled in client. It will be filled in proxy.

tag = partition_name or "_default" # should here?
insert_request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag)
tag = partition_name or "_default" # should here?
request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag)

if not isInsert:
request = milvus_types.UpsertRequest(collection_name=collection_name, partition_name=tag)

for entity in entities:
if not entity.get("name", None) or not entity.get("values", None) or not entity.get("type", None):
raise ParamError(message="Missing param in entities, a field must have type, name and values")
if not fields_info:
raise ParamError(message="Missing collection meta to validate entities")
if not fields_info:
raise ParamError(message="Missing collection meta to validate entities")

location, primary_key_loc, auto_id_loc = {}, None, None
for i, field in enumerate(fields_info):
if field.get("is_primary", False):
primary_key_loc = i

if field.get("auto_id", False):
auto_id_loc = i
continue

match_flag = False
field_name = field["name"]
field_type = field["type"]

for j, entity in enumerate(entities):
entity_name, entity_type = entity["name"], entity["type"]

if field_name == entity_name:
if field_type != entity_type:
raise ParamError(message=f"Collection field type is {field_type}"
f", but entities field type is {entity_type}")

entity_dim, field_dim = 0, 0
if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
field_dim = field["params"]["dim"]
entity_dim = len(entity["values"][0])

if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim:
raise ParamError(message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim}")

if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim:
raise ParamError(message=f"Collection field dim is {field_dim}"
f", but entities field dim is {entity_dim * 8}")

location[field["name"]] = j
match_flag = True
break

if not match_flag:
raise ParamError(message=f"Field {field['name']} don't match in entities")
traverse_info(fields_info, entities, location, primary_key_loc, auto_id_loc)

# though impossible from sdk
if primary_key_loc is None:
Expand All @@ -318,13 +284,13 @@ def batch_insert_param(cls, collection_name, entities, partition_name, fields_in
raise ParamError(message="row num misaligned current[{current}]!= previous[{row_num}]")
row_num = current
field_data = entity_helper.entity_to_field_data(entity, fields_info[location[entity.get("name")]])
insert_request.fields_data.append(field_data)
request.fields_data.append(field_data)
except (TypeError, ValueError) as e:
raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e

insert_request.num_rows = row_num
request.num_rows = row_num

return insert_request
return request

@classmethod
def delete_request(cls, collection_name, partition_name, expr):
Expand Down
45 changes: 44 additions & 1 deletion pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .types import DataType
from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK
from ..exceptions import MilvusException
from ..exceptions import ParamError, MilvusException

valid_index_types = [
"FLAT",
Expand Down Expand Up @@ -149,3 +149,46 @@ def len_of(field_data) -> int:
return int(total_len / (dim / 8))

raise MilvusException(message="Unknown data type")


def traverse_info(fields_info, entities, location, primary_key_loc=None, auto_id_loc=None):
    """Match every schema field against the supplied entities and validate them.

    Fills ``location`` (field name -> index of the matching entity) in place,
    and discovers the indices of the primary-key field and the auto_id field
    in ``fields_info``.

    BUG FIX: the original rebound the ``primary_key_loc`` / ``auto_id_loc``
    parameters, but ints are immutable in Python, so callers that passed
    variables in never saw the discovered indices — only the ``location``
    dict mutation propagated.  The results are therefore now also *returned*;
    callers should use the return value rather than rely on arguments being
    updated.  (The two parameters now default to None, which is backward
    compatible with existing call sites that pass them explicitly.)

    :param fields_info: list of field dicts from the collection schema
        (keys used: name, type, is_primary, auto_id, params.dim).
    :param entities: list of entity dicts (keys used: name, type, values).
    :param location: dict mutated in place, mapping field name -> entity index.
    :param primary_key_loc: initial value for the primary-key index.
    :param auto_id_loc: initial value for the auto_id field index.
    :raises ParamError: on type mismatch, dimension mismatch, or a schema
        field with no matching entity.
    :return: tuple ``(location, primary_key_loc, auto_id_loc)``.
    """
    for i, field in enumerate(fields_info):
        if field.get("is_primary", False):
            primary_key_loc = i

        if field.get("auto_id", False):
            # auto_id fields are generated server-side; no entity is expected.
            auto_id_loc = i
            continue

        match_flag = False
        field_name = field["name"]
        field_type = field["type"]

        for j, entity in enumerate(entities):
            entity_name, entity_type = entity["name"], entity["type"]

            if field_name == entity_name:
                if field_type != entity_type:
                    raise ParamError(message=f"Collection field type is {field_type}"
                                             f", but entities field type is {entity_type}")

                entity_dim, field_dim = 0, 0
                if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
                    field_dim = field["params"]["dim"]
                    entity_dim = len(entity["values"][0])

                if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim:
                    raise ParamError(message=f"Collection field dim is {field_dim}"
                                             f", but entities field dim is {entity_dim}")

                # Binary vectors pack 8 dimensions per byte.
                if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim:
                    raise ParamError(message=f"Collection field dim is {field_dim}"
                                             f", but entities field dim is {entity_dim * 8}")

                location[field["name"]] = j
                match_flag = True
                break

        if not match_flag:
            raise ParamError(
                message=f"Field {field['name']} don't match in entities")

    return location, primary_key_loc, auto_id_loc
5 changes: 5 additions & 0 deletions pymilvus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ class InvalidConsistencyLevel(MilvusException):
""" Raise when consistency level is invalid """


class UpsertAutoIDTrueException(MilvusException):
    """ Raise when upsert is attempted on a collection whose primary field has auto_id == True, which upsert does not support """


class ExceptionsMessage:
NoHostPort = "connection configuration must contain 'host' and 'port'."
HostType = "Type of 'host' must be str."
Expand Down Expand Up @@ -165,3 +169,4 @@ class ExceptionsMessage:
ExprType = "The type of expr must be string ,but %r is given."
EnvConfigErr = "Environment variable %s has a wrong format, please check it: %s"
AmbiguousIndexName = "There are multiple indexes, please specify the index_name."
UpsertAutoIDTrue = "Upsert don't support autoid == true"
56 changes: 51 additions & 5 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@

import copy
import json
from typing import List
from typing import List, Union
import pandas

from .connections import connections
from .schema import (
CollectionSchema,
FieldSchema,
parse_fields_from_data,
check_insert_data_schema,
check_insert_or_upsert_data_schema,
check_schema,
)
from .prepare import Prepare
Expand Down Expand Up @@ -388,7 +388,7 @@ def release(self, timeout=None, **kwargs):
conn = self._get_connection()
conn.release_collection(self._name, timeout=timeout, **kwargs)

def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult:
def insert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult:
""" Insert data into the collection.
Args:
Expand Down Expand Up @@ -423,8 +423,8 @@ def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeo
"""
if data is None:
return MutationResult(data)
check_insert_data_schema(self._schema, data)
entities = Prepare.prepare_insert_data(data, self._schema)
check_insert_or_upsert_data_schema(self._schema, data)
entities = Prepare.prepare_insert_or_upsert_data(data, self._schema)

conn = self._get_connection()
res = conn.batch_insert(self._name, entities, partition_name,
Expand Down Expand Up @@ -477,6 +477,52 @@ def delete(self, expr, partition_name=None, timeout=None, **kwargs):
return MutationFuture(res)
return MutationResult(res)

def upsert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult:
    """ Upsert data into the collection.

    Validates `data` against the collection schema (auto_id collections are
    rejected for upsert), converts it to per-field entities, and sends it via
    the connection's upsert RPC.

    Args:
        data (``list/tuple/pandas.DataFrame``): The specified data to upsert
        partition_name (``str``): The partition name which the data will be upserted into,
            if partition name is not passed, then the data will be upserted into the "_default" partition
        timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None.
            If timeout is set to None, the client keeps waiting until the server responds or an error occurs.
    Returns:
        MutationResult: contains 2 properties `upsert_count`, and, `primary_keys`
            `upsert_count`: how many entities have been upserted at Milvus,
            `primary_keys`: list of primary keys of the upserted entities
    Raises:
        MilvusException: If anything goes wrong.
    Examples:
        >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
        >>> import random
        >>> connections.connect()
        >>> schema = CollectionSchema([
        ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
        ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
        ... ])
        >>> collection = Collection("test_collection_upsert", schema)
        >>> data = [
        ...     [random.randint(1, 100) for _ in range(10)],
        ...     [[random.random() for _ in range(2)] for _ in range(10)],
        ... ]
        >>> res = collection.upsert(data)
        >>> res.upsert_count
        10
    """
    # None is treated as a no-op, mirroring insert().
    if data is None:
        return MutationResult(data)
    # False -> validate with upsert rules (e.g. auto_id collections rejected).
    check_insert_or_upsert_data_schema(self._schema, data, False)
    entities = Prepare.prepare_insert_or_upsert_data(data, self._schema, False)

    conn = self._get_connection()
    res = conn.upsert(self._name, entities, partition_name,
                      timeout=timeout, schema=self._schema_dict, **kwargs)

    # In async mode the raw future is wrapped instead of a resolved result.
    if kwargs.get("_async", False):
        return MutationFuture(res)
    return MutationResult(res)

def search(self, data, anns_field, param, limit, expr=None, partition_names=None,
output_fields=None, timeout=None, round_decimal=-1, **kwargs):
""" Conducts a vector similarity search with an optional boolean expression as filter.
Expand Down
Loading

0 comments on commit 4dd6792

Please sign in to comment.