From b268bcd160433245c9a66c8c4ac6a43044a07279 Mon Sep 17 00:00:00 2001 From: XuanYang-cn Date: Tue, 8 Nov 2022 14:10:59 +0800 Subject: [PATCH] Make error message more clear when schema mismatch (#1217) See also: milvus-io/milvus#16536, #914 Signed-off-by: XuanYang-cn Signed-off-by: XuanYang-cn --- pymilvus/orm/collection.py | 117 +++++++++---------------------------- pymilvus/orm/schema.py | 93 ++++++++++++++++++++++------- 2 files changed, 100 insertions(+), 110 deletions(-) diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index a0f4b48aa..21e8807cd 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -12,6 +12,7 @@ import copy import json +from typing import List import pandas from .connections import connections @@ -19,6 +20,8 @@ CollectionSchema, FieldSchema, parse_fields_from_data, + check_insert_data_schema, + check_schema, ) from .prepare import Prepare from .partition import Partition @@ -29,7 +32,6 @@ from ..exceptions import ( SchemaNotReadyException, DataTypeNotMatchException, - DataNotMatchException, PartitionAlreadyExistException, PartitionNotExistException, IndexNotExistException, @@ -45,18 +47,6 @@ from ..client.configs import DefaultConfigs -def _check_schema(schema): - if schema is None: - raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema) - if len(schema.fields) < 1: - raise SchemaNotReadyException(message=ExceptionsMessage.EmptySchema) - vector_fields = [] - for field in schema.fields: - if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR): - vector_fields.append(field.name) - if len(vector_fields) < 1: - raise SchemaNotReadyException(message=ExceptionsMessage.NoVector) - class Collection: """ This is a class corresponding to collection in milvus. 
""" @@ -136,7 +126,7 @@ def __init__(self, name, schema=None, using="default", shards_num=2, **kwargs): if schema is None: raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % name) if isinstance(schema, CollectionSchema): - _check_schema(schema) + check_schema(schema) consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL)) conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs) self._schema = schema @@ -160,48 +150,6 @@ def __repr__(self): def _get_connection(self): return connections._fetch_handler(self._using) - def _check_insert_data_schema(self, data): - """ - Checks whether the data type matches the schema. - """ - if self._schema is None: - return False - if self._schema.auto_id: - if isinstance(data, pandas.DataFrame): - if self._schema.primary_field.name in data: - if not data[self._schema.primary_field.name].isnull().all(): - raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData) - data = data.drop(self._schema.primary_field.name, axis=1) - - infer_fields = parse_fields_from_data(data) - tmp_fields = copy.deepcopy(self._schema.fields) - - for i, field in enumerate(self._schema.fields): - if field.is_primary and field.auto_id: - tmp_fields.pop(i) - - if len(infer_fields) != len(tmp_fields): - raise DataTypeNotMatchException(message=ExceptionsMessage.FieldsNumInconsistent) - - for x, y in zip(infer_fields, tmp_fields): - if x.dtype != y.dtype: - return False - if isinstance(data, pandas.DataFrame): - if x.name != y.name: - return False - # todo check dim - return True - - def _check_schema(self): - if self._schema is None: - raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema) - - def _get_vector_field(self) -> str: - for field in self._schema.fields: - if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR): - return field.name - raise SchemaNotReadyException(message=ExceptionsMessage.NoVector) - @classmethod def 
construct_from_dataframe(cls, name, dataframe, **kwargs): if dataframe is None: @@ -249,7 +197,7 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs): field.params[DefaultConfigs.MaxVarCharLengthKey] = int(DefaultConfigs.MaxVarCharLength) schema = CollectionSchema(fields=fields_schema) - _check_schema(schema) + check_schema(schema) collection = cls(name, schema, **kwargs) res = collection.insert(data=dataframe) return collection, res @@ -502,37 +450,28 @@ def release(self, timeout=None, **kwargs): conn = self._get_connection() conn.release_collection(self._name, timeout=timeout, **kwargs) - def insert(self, data, partition_name=None, timeout=None, **kwargs): - """ - Insert data into the collection. - - :param data: The specified data to insert, the dimension of data needs to align with column - number - :type data: list-like(list, tuple) object or pandas.DataFrame - :param partition_name: The partition name which the data will be inserted to, if partition - name is not passed, then the data will be inserted to "_default" - partition - :type partition_name: str - - :param timeout: - * *timeout* (``float``) -- - An optional duration of time in seconds to allow for the RPC. If timeout - is set to None, the client keeps waiting until the server responds or an error occurs. - - :return: A MutationResult object contains a property named `insert_count` represents how many - entities have been inserted into milvus and a property named `primary_keys` is a list of primary - keys of the inserted entities. - :rtype: MutationResult - - :raises CollectionNotExistException: If the specified collection does not exist. - :raises ParamError: If input parameters are invalid. - :raises BaseException: If the specified partition does not exist. - - :example: + def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult: + """ Insert data into the collection. 
+ + Args: + data (list, tuple, pandas.DataFrame): The specified data to insert + partition_name (str): The partition name which the data will be inserted to, + if partition name is not passed, then the data will be inserted to "_default" partition + timeout (float, optional): A duration of time in seconds to allow for the RPC. Defaults to None. + If timeout is set to None, the client keeps waiting until the server responds or an error occurs. + Returns: + MutationResult: contains 2 properties `insert_count` and `primary_keys` + `insert_count`: how many entities have been inserted into Milvus, + `primary_keys`: list of primary keys of the inserted entities + Raises: + CollectionNotExistException: If the specified collection does not exist. + ParamError: If input parameters are invalid. + MilvusException: If the specified partition does not exist. + + Examples: >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType >>> import random >>> connections.connect() - >>> schema = CollectionSchema([ ... FieldSchema("film_id", DataType.INT64, is_primary=True), ... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2) @@ -542,14 +481,14 @@ def insert(self, data, partition_name=None, timeout=None, **kwargs): ... [random.randint(1, 100) for _ in range(10)], ... [[random.random() for _ in range(2)] for _ in range(10)], ... 
] - >>> collection.insert(data) - >>> collection.num_entities + >>> res = collection.insert(data) + >>> res.insert_count 10 """ if data is None: return MutationResult(data) - if not self._check_insert_data_schema(data): - raise SchemaNotReadyException(message=ExceptionsMessage.TypeOfDataAndSchemaInconsistent) + check_insert_data_schema(self._schema, data) + + conn = self._get_connection() entities = Prepare.prepare_insert_data(data, self._schema) schema_dict = self._schema.to_dict() diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index de0e186f1..5d8b56d55 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -24,7 +24,9 @@ FieldsTypeException, FieldTypeException, AutoIDException, - ExceptionsMessage + ExceptionsMessage, + DataNotMatchException, + SchemaNotReadyException, ) @@ -287,33 +289,72 @@ def dtype(self): return self._dtype -def parse_fields_from_data(datas): - if isinstance(datas, pandas.DataFrame): - return parse_fields_from_dataframe(datas) - fields = [] - if not isinstance(datas, list): - raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport) - for d in datas: +def check_insert_data_schema(schema: CollectionSchema, data: [List[List], pandas.DataFrame]) -> None: + """ check if the insert data is consistent with the collection schema + + Args: + schema (CollectionSchema): the schema of the collection + data (List[List], pandas.DataFrame): the data to be inserted + + Raises: + SchemaNotReadyException: if the schema is None + DataNotMatchException: if the data is inconsistent with the schema + """ + if schema is None: + raise SchemaNotReadyException(message="Schema shouldn't be None") + if schema.auto_id: + if isinstance(data, pandas.DataFrame): + if schema.primary_field.name in data: + if not data[schema.primary_field.name].isnull().all(): + raise DataNotMatchException(message=f"Please don't provide data for auto_id primary field: {schema.primary_field.name}") + data = data.drop(schema.primary_field.name, axis=1) 
+ + infer_fields = parse_fields_from_data(data) + tmp_fields = copy.deepcopy(schema.fields) + + for i, field in enumerate(schema.fields): + if field.is_primary and field.auto_id: + tmp_fields.pop(i) + + if len(infer_fields) != len(tmp_fields): + + i_name = [f.name for f in infer_fields] + t_name = [f.name for f in tmp_fields] + raise DataNotMatchException(message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}") + + for x, y in zip(infer_fields, tmp_fields): + if x.dtype != y.dtype: + raise DataNotMatchException(message=f"The data type of field {y.name} doesn't match, expected: {y.dtype.name}, got {x.dtype.name}") + if isinstance(data, pandas.DataFrame) and x.name != y.name: + raise DataNotMatchException(message=f"The name of field don't match, expected: {y.name}, got {x.name}") + + +def parse_fields_from_data(data: [List[List], pandas.DataFrame]) -> List[FieldSchema]: + if not isinstance(data, (pandas.DataFrame, list)): + raise DataTypeNotSupportException(message="The type of data should be list or pandas.DataFrame") + + if isinstance(data, pandas.DataFrame): + return parse_fields_from_dataframe(data) + + for d in data: if not is_list_like(d): - raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport) - d_type = infer_dtype_bydata(d[0]) - fields.append(FieldSchema("", d_type)) + raise DataTypeNotSupportException(message="data should be a list of list") + + fields = [FieldSchema("", infer_dtype_bydata(d[0])) for d in data] return fields -def parse_fields_from_dataframe(dataframe) -> List[FieldSchema]: - if not isinstance(dataframe, pandas.DataFrame): - return None - d_types = list(dataframe.dtypes) +def parse_fields_from_dataframe(df: pandas.DataFrame) -> List[FieldSchema]: + d_types = list(df.dtypes) data_types = list(map(map_numpy_dtype_to_datatype, d_types)) - col_names = list(dataframe.columns) + col_names = list(df.columns) column_params_map = {} if DataType.UNKNOWN in data_types: - if len(dataframe) == 0: + 
if len(df) == 0: raise CannotInferSchemaException(message=ExceptionsMessage.DataFrameInvalid) - values = dataframe.head(1).values[0] + values = df.head(1).values[0] for i, dtype in enumerate(data_types): if dtype == DataType.UNKNOWN: new_dtype = infer_dtype_bydata(values[i]) @@ -324,9 +365,6 @@ def parse_fields_from_dataframe(dataframe) -> List[FieldSchema]: else: vector_type_params['dim'] = len(values[i]) column_params_map[col_names[i]] = vector_type_params - # if new_dtype in (DataType.VARCHAR,): - # str_type_params = {} - # str_type_params[DefaultConfigs.MaxVarCharLengthKey] = DefaultConfigs.MaxVarCharLength data_types[i] = new_dtype if DataType.UNKNOWN in data_types: @@ -339,3 +377,16 @@ def parse_fields_from_dataframe(dataframe) -> List[FieldSchema]: fields.append(field_schema) return fields + + +def check_schema(schema): + if schema is None: + raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema) + if len(schema.fields) < 1: + raise SchemaNotReadyException(message=ExceptionsMessage.EmptySchema) + vector_fields = [] + for field in schema.fields: + if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR): + vector_fields.append(field.name) + if len(vector_fields) < 1: + raise SchemaNotReadyException(message=ExceptionsMessage.NoVector)