diff --git a/docs/source/api/collection.rst b/docs/source/api/collection.rst index da8d0ad40..98560f90c 100644 --- a/docs/source/api/collection.rst +++ b/docs/source/api/collection.rst @@ -61,6 +61,8 @@ Methods +---------------------------------------------------------------+--------------------------------------------------------------------------+ | `search() <#pymilvus.Collection.search>`_ | Vector similarity search with an optional boolean expression as filters. | +---------------------------------------------------------------+--------------------------------------------------------------------------+ +| `upsert() <#pymilvus.Collection.upsert>`_ | Upsert data of collection. | ++---------------------------------------------------------------+--------------------------------------------------------------------------+ | `query() <#pymilvus.Collection.query>`_ | Query with a set of criteria. | +---------------------------------------------------------------+--------------------------------------------------------------------------+ | `partition() <#pymilvus.Collection.partition>`_ | Return the partition corresponding to name. | diff --git a/docs/source/api/partition.rst b/docs/source/api/partition.rst index 12762d666..2d68b5e68 100644 --- a/docs/source/api/partition.rst +++ b/docs/source/api/partition.rst @@ -50,6 +50,8 @@ Methods +--------------------------------------------+--------------------------------------------------------------------------+ | `delete() <#pymilvus.Partition.delete>`_ | Delete entities with an expression condition. | +--------------------------------------------+--------------------------------------------------------------------------+ +| `upsert() <#pymilvus.Collection.upsert>`_ |Upsert data of collection. 
| ++--------------------------------------------+--------------------------------------------------------------------------+ | `search() <#pymilvus.Partition.search>`_ | Vector similarity search with an optional boolean expression as filters. | +--------------------------------------------+--------------------------------------------------------------------------+ | `query() <#pymilvus.Partition.query>`_ | Query with a set of criteria. | diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index 82d1276f3..e583dbf20 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -353,22 +353,31 @@ def get_partition_stats(self, collection_name, partition_name, timeout=None, **k raise MilvusException(status.error_code, status.reason) - def _prepare_batch_insert_request(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): - insert_param = kwargs.get('insert_param', None) - - if insert_param and not isinstance(insert_param, milvus_types.RowBatch): - raise ParamError(message="The value of key 'insert_param' is invalid") + def _prepare_batch_insert_or_upsert_request(self, collection_name, entities, partition_name=None, timeout=None, isInsert=True, **kwargs): + param = kwargs.get('insert_param', None) + + if not isInsert: + param = kwargs.get('upsert_param', None) + + if param and not isinstance(param, milvus_types.RowBatch): + if isInsert: + raise ParamError( + message="The value of key 'insert_param' is invalid") + raise ParamError( + message="The value of key 'upsert_param' is invalid") if not isinstance(entities, list): - raise ParamError(message="None entities, please provide valid entities.") + raise ParamError( + message="None entities, please provide valid entities.") collection_schema = kwargs.get("schema", None) if not collection_schema: - collection_schema = self.describe_collection(collection_name, timeout=timeout, **kwargs) + collection_schema = self.describe_collection( + 
collection_name, timeout=timeout, **kwargs) fields_info = collection_schema["fields"] - request = insert_param if insert_param \ - else Prepare.batch_insert_param(collection_name, entities, partition_name, fields_info) + request = param if param \ + else Prepare.batch_insert_or_upsert_param(collection_name, entities, partition_name, fields_info, isInsert) return request @@ -378,7 +387,8 @@ def batch_insert(self, collection_name, entities, partition_name=None, timeout=N raise ParamError(message="Invalid binary vector data exists") try: - request = self._prepare_batch_insert_request(collection_name, entities, partition_name, timeout, **kwargs) + request = self._prepare_batch_insert_or_upsert_request( + collection_name, entities, partition_name, timeout, **kwargs) rf = self._stub.Insert.future(request, timeout=timeout) if kwargs.get("_async", False) is True: cb = kwargs.get("_callback", None) @@ -423,6 +433,34 @@ def delete(self, collection_name, expression, partition_name=None, timeout=None, return MutationFuture(None, None, err) raise err + @retry_on_rpc_failure() + def upsert(self, collection_name, entities, partition_name=None, timeout=None, **kwargs): + if not check_invalid_binary_vector(entities): + raise ParamError(message="Invalid binary vector data exists") + + try: + request = self._prepare_batch_insert_or_upsert_request( + collection_name, entities, partition_name, timeout, False, **kwargs) + rf = self._stub.Upsert.future(request, timeout=timeout) + if kwargs.get("_async", False) is True: + cb = kwargs.get("_callback", None) + f = MutationFuture(rf, cb, timeout=timeout, **kwargs) + f.add_callback(ts_utils.update_ts_on_mutation(collection_name)) + return f + + response = rf.result() + if response.status.error_code == 0: + m = MutationResult(response) + ts_utils.update_collection_ts(collection_name, m.timestamp) + return m + + raise MilvusException( + response.status.error_code, response.status.reason) + except Exception as err: + if kwargs.get("_async", 
False): + return MutationFuture(None, None, err) + raise err + def _execute_search_requests(self, requests, timeout=None, **kwargs): auto_id = kwargs.get("auto_id", True) diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index e08920e95..c401762ed 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -8,6 +8,7 @@ from . import entity_helper from .check import check_pass_param, is_legal_collection_properties from .types import DataType, PlaceholderType, get_consistency_level +from .utils import traverse_info from .constants import DEFAULT_CONSISTENCY_LEVEL from ..exceptions import ParamError, DataNotMatchException, ExceptionsMessage from ..orm.schema import CollectionSchema @@ -247,58 +248,23 @@ def partition_name(cls, collection_name, partition_name): tag=partition_name) @classmethod - def batch_insert_param(cls, collection_name, entities, partition_name, fields_info=None, **kwargs): - # insert_request.hash_keys won't be filled in client. It will be filled in proxy. + def batch_insert_or_upsert_param(cls, collection_name, entities, partition_name, fields_info=None, isInsert=True, **kwargs): + # insert_request.hash_keys and upsert_request.hash_keys won't be filled in client. It will be filled in proxy. - tag = partition_name or "_default" # should here? - insert_request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag) + tag = partition_name or "_default" # should here? 
+ request = milvus_types.InsertRequest(collection_name=collection_name, partition_name=tag) + + if not isInsert: + request = milvus_types.UpsertRequest(collection_name=collection_name, partition_name=tag) for entity in entities: if not entity.get("name", None) or not entity.get("values", None) or not entity.get("type", None): raise ParamError(message="Missing param in entities, a field must have type, name and values") - if not fields_info: - raise ParamError(message="Missing collection meta to validate entities") + if not fields_info: + raise ParamError(message="Missing collection meta to validate entities") location, primary_key_loc, auto_id_loc = {}, None, None - for i, field in enumerate(fields_info): - if field.get("is_primary", False): - primary_key_loc = i - - if field.get("auto_id", False): - auto_id_loc = i - continue - - match_flag = False - field_name = field["name"] - field_type = field["type"] - - for j, entity in enumerate(entities): - entity_name, entity_type = entity["name"], entity["type"] - - if field_name == entity_name: - if field_type != entity_type: - raise ParamError(message=f"Collection field type is {field_type}" - f", but entities field type is {entity_type}") - - entity_dim, field_dim = 0, 0 - if entity_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: - field_dim = field["params"]["dim"] - entity_dim = len(entity["values"][0]) - - if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim: - raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim}") - - if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim: - raise ParamError(message=f"Collection field dim is {field_dim}" - f", but entities field dim is {entity_dim * 8}") - - location[field["name"]] = j - match_flag = True - break - - if not match_flag: - raise ParamError(message=f"Field {field['name']} don't match in entities") + traverse_info(fields_info, entities, location, primary_key_loc, 
auto_id_loc) # though impossible from sdk if primary_key_loc is None: @@ -318,13 +284,13 @@ def batch_insert_param(cls, collection_name, entities, partition_name, fields_in raise ParamError(message="row num misaligned current[{current}]!= previous[{row_num}]") row_num = current field_data = entity_helper.entity_to_field_data(entity, fields_info[location[entity.get("name")]]) - insert_request.fields_data.append(field_data) + request.fields_data.append(field_data) except (TypeError, ValueError) as e: raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e - insert_request.num_rows = row_num + request.num_rows = row_num - return insert_request + return request @classmethod def delete_request(cls, collection_name, partition_name, expr): diff --git a/pymilvus/client/utils.py b/pymilvus/client/utils.py index 477702be9..fbd1c3828 100644 --- a/pymilvus/client/utils.py +++ b/pymilvus/client/utils.py @@ -2,7 +2,7 @@ from .types import DataType from .constants import LOGICAL_BITS, LOGICAL_BITS_MASK -from ..exceptions import MilvusException +from ..exceptions import ParamError, MilvusException valid_index_types = [ "FLAT", @@ -149,3 +149,46 @@ def len_of(field_data) -> int: return int(total_len / (dim / 8)) raise MilvusException(message="Unknown data type") + + +def traverse_info(fields_info, entities, location, primary_key_loc, auto_id_loc): + for i, field in enumerate(fields_info): + if field.get("is_primary", False): + primary_key_loc = i + + if field.get("auto_id", False): + auto_id_loc = i + continue + + match_flag = False + field_name = field["name"] + field_type = field["type"] + + for j, entity in enumerate(entities): + entity_name, entity_type = entity["name"], entity["type"] + + if field_name == entity_name: + if field_type != entity_type: + raise ParamError(message=f"Collection field type is {field_type}" + f", but entities field type is {entity_type}") + + entity_dim, field_dim = 0, 0 + if entity_type in [DataType.FLOAT_VECTOR, 
DataType.BINARY_VECTOR]: + field_dim = field["params"]["dim"] + entity_dim = len(entity["values"][0]) + + if entity_type in [DataType.FLOAT_VECTOR, ] and entity_dim != field_dim: + raise ParamError(message=f"Collection field dim is {field_dim}" + f", but entities field dim is {entity_dim}") + + if entity_type in [DataType.BINARY_VECTOR, ] and entity_dim * 8 != field_dim: + raise ParamError(message=f"Collection field dim is {field_dim}" + f", but entities field dim is {entity_dim * 8}") + + location[field["name"]] = j + match_flag = True + break + + if not match_flag: + raise ParamError( + message=f"Field {field['name']} don't match in entities") diff --git a/pymilvus/exceptions.py b/pymilvus/exceptions.py index 0f1cfc1e0..ccf3715ec 100644 --- a/pymilvus/exceptions.py +++ b/pymilvus/exceptions.py @@ -120,6 +120,10 @@ class InvalidConsistencyLevel(MilvusException): """ Raise when consistency level is invalid """ +class UpsertAutoIDTrueException(MilvusException): + """ Raise when upsert autoID is true """ + + class ExceptionsMessage: NoHostPort = "connection configuration must contain 'host' and 'port'." HostType = "Type of 'host' must be str." @@ -165,3 +169,4 @@ class ExceptionsMessage: ExprType = "The type of expr must be string ,but %r is given." EnvConfigErr = "Environment variable %s has a wrong format, please check it: %s" AmbiguousIndexName = "There are multiple indexes, please specify the index_name." 
+ UpsertAutoIDTrue = "Upsert don't support autoid == true" diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 288a7347c..ed159b6ab 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -12,7 +12,7 @@ import copy import json -from typing import List +from typing import List, Union import pandas from .connections import connections @@ -20,7 +20,7 @@ CollectionSchema, FieldSchema, parse_fields_from_data, - check_insert_data_schema, + check_insert_or_upsert_data_schema, check_schema, ) from .prepare import Prepare @@ -388,7 +388,7 @@ def release(self, timeout=None, **kwargs): conn = self._get_connection() conn.release_collection(self._name, timeout=timeout, **kwargs) - def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult: + def insert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult: """ Insert data into the collection. Args: @@ -423,8 +423,8 @@ def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeo """ if data is None: return MutationResult(data) - check_insert_data_schema(self._schema, data) - entities = Prepare.prepare_insert_data(data, self._schema) + check_insert_or_upsert_data_schema(self._schema, data) + entities = Prepare.prepare_insert_or_upsert_data(data, self._schema) conn = self._get_connection() res = conn.batch_insert(self._name, entities, partition_name, @@ -477,6 +477,52 @@ def delete(self, expr, partition_name=None, timeout=None, **kwargs): return MutationFuture(res) return MutationResult(res) + def upsert(self, data: Union[List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult: + """ Upsert data into the collection. 
+ + Args: + data (``list/tuple/pandas.DataFrame``): The specified data to upsert + partition_name (``str``): The partition name which the data will be upserted at, + if partition name is not passed, then the data will be upserted in "_default" partition + timeout (``float``, optional): A duration of time in seconds to allow for the RPC. Defaults to None. + If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + Returns: + MutationResult: contains 2 properties `upsert_count` and `primary_keys` + `upsert_count`: how many entities have been upserted at Milvus, + `primary_keys`: list of primary keys of the upserted entities + Raises: + MilvusException: If anything goes wrong. + + Examples: + >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType + >>> import random + >>> connections.connect() + >>> schema = CollectionSchema([ + ... FieldSchema("film_id", DataType.INT64, is_primary=True), + ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) + ... ]) + >>> collection = Collection("test_collection_upsert", schema) + >>> data = [ + ... [random.randint(1, 100) for _ in range(10)], + ... [[random.random() for _ in range(2)] for _ in range(10)], + ... ] + >>> res = collection.upsert(data) + >>> res.upsert_count + 10 + """ + if data is None: + return MutationResult(data) + check_insert_or_upsert_data_schema(self._schema, data, False) + entities = Prepare.prepare_insert_or_upsert_data(data, self._schema, False) + + conn = self._get_connection() + res = conn.upsert(self._name, entities, partition_name, + timeout=timeout, schema=self._schema_dict, **kwargs) + + if kwargs.get("_async", False): + return MutationFuture(res) + return MutationResult(res) + def search(self, data, anns_field, param, limit, expr=None, partition_names=None, output_fields=None, timeout=None, round_decimal=-1, **kwargs): """ Conducts a vector similarity search with an optional boolean expression as filter. 
diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py index 73a3cff3e..586346b95 100644 --- a/pymilvus/orm/partition.py +++ b/pymilvus/orm/partition.py @@ -296,7 +296,7 @@ def insert(self, data, timeout=None, **kwargs): if conn.has_partition(self._collection.name, self._name, **kwargs) is False: raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) # TODO: check insert data schema here? - entities = Prepare.prepare_insert_data(data, self._collection.schema) + entities = Prepare.prepare_insert_or_upsert_data(data, self._collection.schema) res = conn.batch_insert(self._collection.name, entities=entities, partition_name=self._name, timeout=timeout, orm=True, schema=self._schema_dict, **kwargs) if kwargs.get("_async", False): @@ -348,6 +348,59 @@ def delete(self, expr, timeout=None, **kwargs): return MutationFuture(res) return MutationResult(res) + def upsert(self, data, timeout=None, **kwargs): + """ + Upsert data into partition. + + :param data: The specified data to upsert, the dimension of data needs to align with column + number + :type data: list-like(list, tuple) object or pandas.DataFrame + + :param timeout: An optional duration of time in seconds to allow for the RPC. When timeout + is set to None, client waits until server response or error occur + :type timeout: float + + :param kwargs: + * *timeout* (``float``) -- + An optional duration of time in seconds to allow for the RPC. When timeout + is set to None, client waits until server response or error occur. + + :return: A MutationResult object contains a property named `upsert_count` represents how many + entities have been upserted at milvus and a property named `primary_keys` is a list of primary + keys of the upserted entities. 
+ :rtype: MutationResult + + :raises PartitionNotExistException: + When partition does not exist + + :example: + >>> from pymilvus import connections, Collection, Partition, FieldSchema, CollectionSchema, DataType + >>> connections.connect() + >>> schema = CollectionSchema([ + ... FieldSchema("film_id", DataType.INT64, is_primary=True), + ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) + ... ]) + >>> collection = Collection("test_partition_upsert", schema) + >>> partition = Partition(collection, "comedy", "comedy films") + >>> data = [ + ... [i for i in range(10)], + ... [[float(i) for i in range(2)] for _ in range(10)], + ... ] + >>> partition.upsert(data) + >>> partition.num_entities + 10 + """ + conn = self._get_connection() + if conn.has_partition(self._collection.name, self._name, **kwargs) is False: + raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) + # TODO: check upsert data schema here? + entities = Prepare.prepare_insert_or_upsert_data(data, self._collection.schema,False) + res = conn.upsert(self._collection.name, entities=entities, partition_name=self._name, + timeout=timeout, orm=True, schema=self._schema_dict, **kwargs) + if kwargs.get("_async", False): + return MutationFuture(res) + return MutationResult(res) + def search(self, data, anns_field, param, limit, expr=None, output_fields=None, timeout=None, round_decimal=-1, **kwargs): """ Conducts a vector similarity search with an optional boolean expression as filter. 
diff --git a/pymilvus/orm/prepare.py b/pymilvus/orm/prepare.py index 3fc843555..7fef129f8 100644 --- a/pymilvus/orm/prepare.py +++ b/pymilvus/orm/prepare.py @@ -14,12 +14,17 @@ import numpy import pandas -from ..exceptions import DataNotMatchException, DataTypeNotSupportException, ExceptionsMessage +from ..exceptions import ( + DataNotMatchException, + DataTypeNotSupportException, + ExceptionsMessage, + UpsertAutoIDTrueException, +) class Prepare: @classmethod - def prepare_insert_data(cls, data, schema): + def prepare_insert_or_upsert_data(cls, data, schema, isInsert = True): if not isinstance(data, (list, tuple, pandas.DataFrame)): raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport) @@ -29,6 +34,8 @@ def prepare_insert_data(cls, data, schema): if isinstance(data, pandas.DataFrame): if schema.auto_id: + if isInsert is False: + raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue) if schema.primary_field.name in data: if len(fields) != len(data.columns): raise DataNotMatchException(message=ExceptionsMessage.FieldsNumInconsistent) diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index 5d8b56d55..140d47313 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -11,7 +11,7 @@ # the License. 
import copy -from typing import List +from typing import List, Union import pandas from pandas.api.types import is_list_like @@ -27,6 +27,7 @@ ExceptionsMessage, DataNotMatchException, SchemaNotReadyException, + UpsertAutoIDTrueException ) @@ -289,26 +290,29 @@ def dtype(self): return self._dtype -def check_insert_data_schema(schema: CollectionSchema, data: [List[List], pandas.DataFrame]) -> None: - """ check if the insert data is consist with the collection schema +def check_insert_or_upsert_data_schema(schema: CollectionSchema, data: Union[List[List], pandas.DataFrame], isInsert=True) -> None: + """ check if the insert or upsert data is consistent with the collection schema Args: schema (CollectionSchema): the schema of the collection - data (List[List], pandas.DataFrame): the data to be inserted + data (List[List], pandas.DataFrame): the data to be inserted or upserted Raise: SchemaNotReadyException: if the schema is None + UpsertAutoIDTrueException: if autoid option is true DataNotMatchException: if the data is in consist with the schema """ if schema is None: raise SchemaNotReadyException(message="Schema shouldn't be None") if schema.auto_id: - if isinstance(data, pandas.DataFrame): - if schema.primary_field.name in data: - if not data[schema.primary_field.name].isnull().all(): - raise DataNotMatchException(message=f"Please don't provide data for auto_id primary field: {schema.primary_field.name}") - data = data.drop(schema.primary_field.name, axis=1) - + if isInsert: + if isinstance(data, pandas.DataFrame): + if schema.primary_field.name in data: + if not data[schema.primary_field.name].isnull().all(): + raise DataNotMatchException(message=f"Please don't provide data for auto_id primary field: {schema.primary_field.name}") + data = data.drop(schema.primary_field.name, axis=1) + else: + raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue) infer_fields = parse_fields_from_data(data) tmp_fields = copy.deepcopy(schema.fields) @@ -329,7 +333,7 
@@ def check_insert_data_schema(schema: CollectionSchema, data: [List[List], pandas raise DataNotMatchException(message=f"The name of field don't match, expected: {y.name}, got {x.name}") -def parse_fields_from_data(data: [List[List], pandas.DataFrame]) -> List[FieldSchema]: +def parse_fields_from_data(data: Union[List[List], pandas.DataFrame]) -> List[FieldSchema]: if not isinstance(data, (pandas.DataFrame, list)): raise DataTypeNotSupportException(message="The type of data should be list or pandas.DataFrame")