Commit b268bcd
Make error message more clear when schema mismatch (#1217)
See also: milvus-io/milvus#16536, #914

Signed-off-by: XuanYang-cn <[email protected]>

XuanYang-cn authored Nov 8, 2022
1 parent 7be2a67 commit b268bcd
Showing 2 changed files with 100 additions and 110 deletions.
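The user-visible effect: a mismatched insert no longer fails with the generic SchemaNotReadyException (ExceptionsMessage.TypeOfDataAndSchemaInconsistent) but raises DataNotMatchException naming the expected and inferred fields. A minimal sketch of a repro (assumes a reachable Milvus server; the collection name "demo" is illustrative):

    from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType

    connections.connect()  # assumes a Milvus instance on the default host/port

    schema = CollectionSchema([
        FieldSchema("film_id", DataType.INT64, is_primary=True),
        FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2),
    ])
    collection = Collection("demo", schema)

    # Three columns against a two-field schema: after this commit the error
    # lists the mismatch, e.g.
    #   DataNotMatchException: The fields don't match with schema fields,
    #   expected: ['film_id', 'films'], got ['', '', '']
    collection.insert([[1, 2], [[0.1, 0.2], [0.3, 0.4]], ["a", "b"]])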
117 changes: 28 additions & 89 deletions pymilvus/orm/collection.py
@@ -12,13 +12,16 @@

 import copy
 import json
+from typing import List
 import pandas

 from .connections import connections
 from .schema import (
     CollectionSchema,
     FieldSchema,
     parse_fields_from_data,
+    check_insert_data_schema,
+    check_schema,
 )
 from .prepare import Prepare
 from .partition import Partition
@@ -29,7 +32,6 @@
 from ..exceptions import (
     SchemaNotReadyException,
     DataTypeNotMatchException,
-    DataNotMatchException,
     PartitionAlreadyExistException,
     PartitionNotExistException,
     IndexNotExistException,
@@ -45,18 +47,6 @@
 from ..client.configs import DefaultConfigs


-def _check_schema(schema):
-    if schema is None:
-        raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema)
-    if len(schema.fields) < 1:
-        raise SchemaNotReadyException(message=ExceptionsMessage.EmptySchema)
-    vector_fields = []
-    for field in schema.fields:
-        if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
-            vector_fields.append(field.name)
-    if len(vector_fields) < 1:
-        raise SchemaNotReadyException(message=ExceptionsMessage.NoVector)
-
-
 class Collection:
     """ This is a class corresponding to collection in milvus. """
@@ -136,7 +126,7 @@ def __init__(self, name, schema=None, using="default", shards_num=2, **kwargs):
         if schema is None:
             raise SchemaNotReadyException(message=ExceptionsMessage.CollectionNotExistNoSchema % name)
         if isinstance(schema, CollectionSchema):
-            _check_schema(schema)
+            check_schema(schema)
             consistency_level = get_consistency_level(kwargs.get("consistency_level", DEFAULT_CONSISTENCY_LEVEL))
             conn.create_collection(self._name, schema, shards_num=self._shards_num, **kwargs)
             self._schema = schema
@@ -160,48 +150,6 @@ def __repr__(self):
     def _get_connection(self):
         return connections._fetch_handler(self._using)

-    def _check_insert_data_schema(self, data):
-        """
-        Checks whether the data type matches the schema.
-        """
-        if self._schema is None:
-            return False
-        if self._schema.auto_id:
-            if isinstance(data, pandas.DataFrame):
-                if self._schema.primary_field.name in data:
-                    if not data[self._schema.primary_field.name].isnull().all():
-                        raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData)
-                    data = data.drop(self._schema.primary_field.name, axis=1)
-
-        infer_fields = parse_fields_from_data(data)
-        tmp_fields = copy.deepcopy(self._schema.fields)
-
-        for i, field in enumerate(self._schema.fields):
-            if field.is_primary and field.auto_id:
-                tmp_fields.pop(i)
-
-        if len(infer_fields) != len(tmp_fields):
-            raise DataTypeNotMatchException(message=ExceptionsMessage.FieldsNumInconsistent)
-
-        for x, y in zip(infer_fields, tmp_fields):
-            if x.dtype != y.dtype:
-                return False
-            if isinstance(data, pandas.DataFrame):
-                if x.name != y.name:
-                    return False
-            # todo check dim
-        return True
-
-    def _check_schema(self):
-        if self._schema is None:
-            raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema)
-
-    def _get_vector_field(self) -> str:
-        for field in self._schema.fields:
-            if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
-                return field.name
-        raise SchemaNotReadyException(message=ExceptionsMessage.NoVector)
-
     @classmethod
     def construct_from_dataframe(cls, name, dataframe, **kwargs):
         if dataframe is None:

@@ -249,7 +197,7 @@ def construct_from_dataframe(cls, name, dataframe, **kwargs):
             field.params[DefaultConfigs.MaxVarCharLengthKey] = int(DefaultConfigs.MaxVarCharLength)
         schema = CollectionSchema(fields=fields_schema)

-        _check_schema(schema)
+        check_schema(schema)
         collection = cls(name, schema, **kwargs)
         res = collection.insert(data=dataframe)
         return collection, res
@@ -502,37 +450,28 @@ def release(self, timeout=None, **kwargs):
         conn = self._get_connection()
         conn.release_collection(self._name, timeout=timeout, **kwargs)

-    def insert(self, data, partition_name=None, timeout=None, **kwargs):
-        """
-        Insert data into the collection.
-        :param data: The specified data to insert, the dimension of data needs to align with column
-                     number
-        :type data: list-like(list, tuple) object or pandas.DataFrame
-        :param partition_name: The partition name which the data will be inserted to, if partition
-                               name is not passed, then the data will be inserted to "_default"
-                               partition
-        :type partition_name: str
-        :param timeout:
-            * *timeout* (``float``) --
-              An optional duration of time in seconds to allow for the RPC. If timeout
-              is set to None, the client keeps waiting until the server responds or an error occurs.
-        :return: A MutationResult object contains a property named `insert_count` represents how many
-                 entities have been inserted into milvus and a property named `primary_keys` is a list of primary
-                 keys of the inserted entities.
-        :rtype: MutationResult
-        :raises CollectionNotExistException: If the specified collection does not exist.
-        :raises ParamError: If input parameters are invalid.
-        :raises BaseException: If the specified partition does not exist.
-        :example:
+    def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeout=None, **kwargs) -> MutationResult:
+        """ Insert data into the collection.
+        Args:
+            data (list, tuple, pandas.DataFrame): The specified data to insert.
+            partition_name (str): The partition name which the data will be inserted into;
+                if the partition name is not passed, the data will be inserted into the "_default" partition.
+            timeout (float, optional): A duration of time in seconds to allow for the RPC. Defaults to None.
+                If timeout is set to None, the client keeps waiting until the server responds or an error occurs.
+        Returns:
+            MutationResult: contains 2 properties, `insert_count` and `primary_keys`.
+                `insert_count`: how many entities have been inserted into Milvus,
+                `primary_keys`: list of primary keys of the inserted entities.
+        Raises:
+            CollectionNotExistException: If the specified collection does not exist.
+            ParamError: If input parameters are invalid.
+            MilvusException: If the specified partition does not exist.
+        Examples:
             >>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
             >>> import random
             >>> connections.connect()
-            <pymilvus.client.stub.Milvus object at 0x7f8579002dc0>
             >>> schema = CollectionSchema([
             ...     FieldSchema("film_id", DataType.INT64, is_primary=True),
             ...     FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
@@ -542,14 +481,14 @@ def insert(self, data, partition_name=None, timeout=None, **kwargs):
             ...     [random.randint(1, 100) for _ in range(10)],
             ...     [[random.random() for _ in range(2)] for _ in range(10)],
             ... ]
-            >>> collection.insert(data)
-            >>> collection.num_entities
+            >>> res = collection.insert(data)
+            >>> res.insert_count
             10
         """
         if data is None:
             return MutationResult(data)
-        if not self._check_insert_data_schema(data):
-            raise SchemaNotReadyException(message=ExceptionsMessage.TypeOfDataAndSchemaInconsistent)
+        check_insert_data_schema(self._schema, data)

         conn = self._get_connection()
         entities = Prepare.prepare_insert_data(data, self._schema)
         schema_dict = self._schema.to_dict()
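With the check moved out of Collection, insert() now raises DataNotMatchException (from pymilvus.exceptions) instead of returning False internally. A sketch of catching it, reusing the illustrative "demo" collection from the sketch above (the exact message depends on the inferred dtype):

    from pymilvus.exceptions import DataNotMatchException

    try:
        # floats supplied for an INT64 primary key column
        collection.insert([[1.5, 2.5], [[0.1, 0.2], [0.3, 0.4]]])
    except DataNotMatchException as e:
        # e.g. "The data type of field film_id doesn't match, expected: INT64, got DOUBLE"
        print(e)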
93 changes: 72 additions & 21 deletions pymilvus/orm/schema.py
@@ -24,7 +24,9 @@
     FieldsTypeException,
     FieldTypeException,
     AutoIDException,
-    ExceptionsMessage
+    ExceptionsMessage,
+    DataNotMatchException,
+    SchemaNotReadyException,
 )


@@ -287,33 +289,72 @@ def dtype(self):
         return self._dtype


-def parse_fields_from_data(datas):
-    if isinstance(datas, pandas.DataFrame):
-        return parse_fields_from_dataframe(datas)
-    fields = []
-    if not isinstance(datas, list):
-        raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport)
-    for d in datas:
+def check_insert_data_schema(schema: CollectionSchema, data: [List[List], pandas.DataFrame]) -> None:
+    """ Check whether the insert data is consistent with the collection schema.
+    Args:
+        schema (CollectionSchema): the schema of the collection
+        data (List[List], pandas.DataFrame): the data to be inserted
+    Raises:
+        SchemaNotReadyException: if the schema is None
+        DataNotMatchException: if the data is inconsistent with the schema
+    """
+    if schema is None:
+        raise SchemaNotReadyException(message="Schema shouldn't be None")
+    if schema.auto_id:
+        if isinstance(data, pandas.DataFrame):
+            if schema.primary_field.name in data:
+                if not data[schema.primary_field.name].isnull().all():
+                    raise DataNotMatchException(message=f"Please don't provide data for auto_id primary field: {schema.primary_field.name}")
+                data = data.drop(schema.primary_field.name, axis=1)
+
+    infer_fields = parse_fields_from_data(data)
+    tmp_fields = copy.deepcopy(schema.fields)
+
+    for i, field in enumerate(schema.fields):
+        if field.is_primary and field.auto_id:
+            tmp_fields.pop(i)
+
+    if len(infer_fields) != len(tmp_fields):
+        i_name = [f.name for f in infer_fields]
+        t_name = [f.name for f in tmp_fields]
+        raise DataNotMatchException(message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}")
+
+    for x, y in zip(infer_fields, tmp_fields):
+        if x.dtype != y.dtype:
+            raise DataNotMatchException(message=f"The data type of field {y.name} doesn't match, expected: {y.dtype.name}, got {x.dtype.name}")
+        if isinstance(data, pandas.DataFrame) and x.name != y.name:
+            raise DataNotMatchException(message=f"The name of field don't match, expected: {y.name}, got {x.name}")


+def parse_fields_from_data(data: [List[List], pandas.DataFrame]) -> List[FieldSchema]:
+    if not isinstance(data, (pandas.DataFrame, list)):
+        raise DataTypeNotSupportException(message="The type of data should be list or pandas.DataFrame")
+
+    if isinstance(data, pandas.DataFrame):
+        return parse_fields_from_dataframe(data)
+
+    for d in data:
         if not is_list_like(d):
-            raise DataTypeNotSupportException(message=ExceptionsMessage.DataTypeNotSupport)
-        d_type = infer_dtype_bydata(d[0])
-        fields.append(FieldSchema("", d_type))
+            raise DataTypeNotSupportException(message="data should be a list of list")
+
+    fields = [FieldSchema("", infer_dtype_bydata(d[0])) for d in data]
     return fields
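For reference, a sketch of what the rewritten parse_fields_from_data infers from column-oriented list data (the exact dtypes depend on infer_dtype_bydata):

    from pymilvus.orm.schema import parse_fields_from_data

    columns = [
        [1, 2, 3],                              # integers
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],   # float vectors
    ]
    fields = parse_fields_from_data(columns)
    print([f.dtype for f in fields])  # e.g. [DataType.INT64, DataType.FLOAT_VECTOR]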


-def parse_fields_from_dataframe(dataframe) -> List[FieldSchema]:
-    if not isinstance(dataframe, pandas.DataFrame):
-        return None
-    d_types = list(dataframe.dtypes)
+def parse_fields_from_dataframe(df: pandas.DataFrame) -> List[FieldSchema]:
+    d_types = list(df.dtypes)
     data_types = list(map(map_numpy_dtype_to_datatype, d_types))
-    col_names = list(dataframe.columns)
+    col_names = list(df.columns)

     column_params_map = {}

     if DataType.UNKNOWN in data_types:
-        if len(dataframe) == 0:
+        if len(df) == 0:
             raise CannotInferSchemaException(message=ExceptionsMessage.DataFrameInvalid)
-        values = dataframe.head(1).values[0]
+        values = df.head(1).values[0]
         for i, dtype in enumerate(data_types):
             if dtype == DataType.UNKNOWN:
                 new_dtype = infer_dtype_bydata(values[i])
@@ -324,9 +365,6 @@ def parse_fields_from_dataframe(dataframe) -> List[FieldSchema]:
                     else:
                         vector_type_params['dim'] = len(values[i])
                     column_params_map[col_names[i]] = vector_type_params
-                    # if new_dtype in (DataType.VARCHAR,):
-                    #     str_type_params = {}
-                    #     str_type_params[DefaultConfigs.MaxVarCharLengthKey] = DefaultConfigs.MaxVarCharLength
                 data_types[i] = new_dtype
if DataType.UNKNOWN in data_types:
@@ -339,3 +377,16 @@
         fields.append(field_schema)

     return fields
+
+
+def check_schema(schema):
+    if schema is None:
+        raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema)
+    if len(schema.fields) < 1:
+        raise SchemaNotReadyException(message=ExceptionsMessage.EmptySchema)
+    vector_fields = []
+    for field in schema.fields:
+        if field.dtype in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
+            vector_fields.append(field.name)
+    if len(vector_fields) < 1:
+        raise SchemaNotReadyException(message=ExceptionsMessage.NoVector)
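Finally, a sketch of what the relocated check_schema rejects (the no-vector schema below is illustrative, not from the commit):

    from pymilvus import CollectionSchema, FieldSchema, DataType
    from pymilvus.orm.schema import check_schema

    # A schema with no FLOAT_VECTOR/BINARY_VECTOR field fails the last guard
    no_vector = CollectionSchema([FieldSchema("pk", DataType.INT64, is_primary=True)])
    check_schema(no_vector)  # raises SchemaNotReadyException (ExceptionsMessage.NoVector)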
