Skip to content

Commit

Permalink
Refine the error message for type mismatches during data insertion
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Zhang <[email protected]>
  • Loading branch information
xiaocai2333 committed Jun 27, 2024
1 parent 0318663 commit b327323
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 82 deletions.
316 changes: 236 additions & 80 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,84 +242,159 @@ def pack_field_value_to_field_data(
field_value: Any, field_data: schema_types.FieldData, field_info: Any
):
field_type = field_data.type
field_name = field_info["name"]
if field_type == DataType.BOOL:
field_data.scalars.bool_data.data.append(field_value)
try:
field_data.scalars.bool_data.data.append(field_value)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "bool", type(field_value))
) from e
elif field_type in (DataType.INT8, DataType.INT16, DataType.INT32):
field_data.scalars.int_data.data.append(field_value)
try:
field_data.scalars.int_data.data.append(field_value)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "int", type(field_value))
) from e
elif field_type == DataType.INT64:
field_data.scalars.long_data.data.append(field_value)
try:
field_data.scalars.long_data.data.append(field_value)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "int64", type(field_value))
) from e
elif field_type == DataType.FLOAT:
field_data.scalars.float_data.data.append(field_value)
try:
field_data.scalars.float_data.data.append(field_value)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "float", type(field_value))
) from e
elif field_type == DataType.DOUBLE:
field_data.scalars.double_data.data.append(field_value)
try:
field_data.scalars.double_data.data.append(field_value)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "double", type(field_value))
) from e
elif field_type == DataType.FLOAT_VECTOR:
f_value = field_value
if isinstance(field_value, np.ndarray):
if field_value.dtype not in ("float32", "float64"):
raise ParamError(
message="invalid input for float32 vector, expect np.ndarray with dtype=float32"
)
f_value = field_value.tolist()

field_data.vectors.dim = len(f_value)
field_data.vectors.float_vector.data.extend(f_value)

try:
f_value = field_value
if isinstance(field_value, np.ndarray):
if field_value.dtype not in ("float32", "float64"):
raise ParamError(
message="invalid input for float32 vector, expect np.ndarray with dtype=float32"
)
f_value = field_value.tolist()

field_data.vectors.dim = len(f_value)
field_data.vectors.float_vector.data.extend(f_value)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "float_vector", type(field_value))
) from e
elif field_type == DataType.BINARY_VECTOR:
field_data.vectors.dim = len(field_value) * 8
field_data.vectors.binary_vector += bytes(field_value)

try:
field_data.vectors.dim = len(field_value) * 8
field_data.vectors.binary_vector += bytes(field_value)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "binary_vector", type(field_value))
) from e
elif field_type == DataType.FLOAT16_VECTOR:
if isinstance(field_value, bytes):
v_bytes = field_value
elif isinstance(field_value, np.ndarray):
if field_value.dtype != "float16":
try:
if isinstance(field_value, bytes):
v_bytes = field_value
elif isinstance(field_value, np.ndarray):
if field_value.dtype != "float16":
raise ParamError(
message="invalid input for float16 vector, expect np.ndarray with dtype=float16"
)
v_bytes = field_value.view(np.uint8).tobytes()
else:
raise ParamError(
message="invalid input for float16 vector, expect np.ndarray with dtype=float16"
message="invalid input type for float16 vector, expect np.ndarray with dtype=float16"
)
v_bytes = field_value.view(np.uint8).tobytes()
else:
raise ParamError(
message="invalid input type for float16 vector, expect np.ndarray with dtype=float16"
)

field_data.vectors.dim = len(v_bytes) // 2
field_data.vectors.float16_vector += v_bytes

field_data.vectors.dim = len(v_bytes) // 2
field_data.vectors.float16_vector += v_bytes
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "float16_vector", type(field_value))
) from e
elif field_type == DataType.BFLOAT16_VECTOR:
if isinstance(field_value, bytes):
v_bytes = field_value
elif isinstance(field_value, np.ndarray):
if field_value.dtype != "bfloat16":
try:
if isinstance(field_value, bytes):
v_bytes = field_value
elif isinstance(field_value, np.ndarray):
if field_value.dtype != "bfloat16":
raise ParamError(
message="invalid input for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
)
v_bytes = field_value.view(np.uint8).tobytes()
else:
raise ParamError(
message="invalid input for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
message="invalid input type for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
)
v_bytes = field_value.view(np.uint8).tobytes()
else:
raise ParamError(
message="invalid input type for bfloat16 vector, expect np.ndarray with dtype=bfloat16"
)

field_data.vectors.dim = len(v_bytes) // 2
field_data.vectors.bfloat16_vector += v_bytes
field_data.vectors.dim = len(v_bytes) // 2
field_data.vectors.bfloat16_vector += v_bytes
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "bfloat16_vector", type(field_value))
) from e
elif field_type == DataType.SPARSE_FLOAT_VECTOR:
# field_value is a single row of sparse float vector in user provided format
if not SciPyHelper.is_scipy_sparse(field_value):
field_value = [field_value]
elif field_value.shape[0] != 1:
raise ParamError(message="invalid input for sparse float vector: expect 1 row")
if not entity_is_sparse_matrix(field_value):
raise ParamError(message="invalid input for sparse float vector")
field_data.vectors.sparse_float_vector.contents.append(
sparse_rows_to_proto(field_value).contents[0]
)
try:
if not SciPyHelper.is_scipy_sparse(field_value):
field_value = [field_value]
elif field_value.shape[0] != 1:
raise ParamError(message="invalid input for sparse float vector: expect 1 row")
if not entity_is_sparse_matrix(field_value):
raise ParamError(message="invalid input for sparse float vector")
field_data.vectors.sparse_float_vector.contents.append(
sparse_rows_to_proto(field_value).contents[0]
)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "sparse_float_vector", type(field_value))
) from e
elif field_type == DataType.VARCHAR:
field_data.scalars.string_data.data.append(
convert_to_str_array(field_value, field_info, CHECK_STR_ARRAY)
)
try:
field_data.scalars.string_data.data.append(
convert_to_str_array(field_value, field_info, CHECK_STR_ARRAY)
)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "varchar", type(field_value))
) from e
elif field_type == DataType.JSON:
field_data.scalars.json_data.data.append(convert_to_json(field_value))
try:
field_data.scalars.json_data.data.append(convert_to_json(field_value))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "json", type(field_value))
) from e
elif field_type == DataType.ARRAY:
field_data.scalars.array_data.data.append(convert_to_array(field_value, field_info))
try:
field_data.scalars.array_data.data.append(convert_to_array(field_value, field_info))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "array", type(field_value))
) from e
else:
raise ParamError(message=f"UnSupported data type: {field_type}")

Expand All @@ -329,42 +404,123 @@ def entity_to_field_data(entity: Any, field_info: Any):
field_data = schema_types.FieldData()

entity_type = entity.get("type")
field_data.field_name = entity.get("name")
field_name = entity.get("name")
field_data.field_name = field_name
field_data.type = entity_type_to_dtype(entity_type)

if entity_type == DataType.BOOL:
field_data.scalars.bool_data.data.extend(entity.get("values"))
try:
field_data.scalars.bool_data.data.extend(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "bool", type(entity.get("values")[0]))
) from e
elif entity_type in (DataType.INT8, DataType.INT16, DataType.INT32):
field_data.scalars.int_data.data.extend(entity.get("values"))
try:
field_data.scalars.int_data.data.extend(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "int", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.INT64:
field_data.scalars.long_data.data.extend(entity.get("values"))
try:
field_data.scalars.long_data.data.extend(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "int64", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.FLOAT:
field_data.scalars.float_data.data.extend(entity.get("values"))
try:
field_data.scalars.float_data.data.extend(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "float", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.DOUBLE:
field_data.scalars.double_data.data.extend(entity.get("values"))
try:
field_data.scalars.double_data.data.extend(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "double", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.FLOAT_VECTOR:
field_data.vectors.dim = len(entity.get("values")[0])
all_floats = [f for vector in entity.get("values") for f in vector]
field_data.vectors.float_vector.data.extend(all_floats)
try:
field_data.vectors.dim = len(entity.get("values")[0])
all_floats = [f for vector in entity.get("values") for f in vector]
field_data.vectors.float_vector.data.extend(all_floats)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "float_vector", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.BINARY_VECTOR:
field_data.vectors.dim = len(entity.get("values")[0]) * 8
field_data.vectors.binary_vector = b"".join(entity.get("values"))
try:
field_data.vectors.dim = len(entity.get("values")[0]) * 8
field_data.vectors.binary_vector = b"".join(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "binary_vector", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.FLOAT16_VECTOR:
field_data.vectors.dim = len(entity.get("values")[0]) // 2
field_data.vectors.float16_vector = b"".join(entity.get("values"))
try:
field_data.vectors.dim = len(entity.get("values")[0]) // 2
field_data.vectors.float16_vector = b"".join(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "float16_vector", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.BFLOAT16_VECTOR:
field_data.vectors.dim = len(entity.get("values")[0]) // 2
field_data.vectors.bfloat16_vector = b"".join(entity.get("values"))
try:
field_data.vectors.dim = len(entity.get("values")[0]) // 2
field_data.vectors.bfloat16_vector = b"".join(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "bfloat16_vector", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.VARCHAR:
field_data.scalars.string_data.data.extend(
entity_to_str_arr(entity, field_info, CHECK_STR_ARRAY)
)
try:
field_data.scalars.string_data.data.extend(
entity_to_str_arr(entity, field_info, CHECK_STR_ARRAY)
)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "varchar", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.JSON:
field_data.scalars.json_data.data.extend(entity_to_json_arr(entity))
try:
field_data.scalars.json_data.data.extend(entity_to_json_arr(entity))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "json", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.ARRAY:
field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info))
try:
field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "array", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.SPARSE_FLOAT_VECTOR:
field_data.vectors.sparse_float_vector.CopyFrom(sparse_rows_to_proto(entity.get("values")))
try:
field_data.vectors.sparse_float_vector.CopyFrom(
sparse_rows_to_proto(entity.get("values"))
)
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "sparse_float_vector", type(entity.get("values")[0]))
) from e
else:
raise ParamError(message=f"UnSupported data type: {entity_type}")

Expand Down
13 changes: 12 additions & 1 deletion pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,13 @@ def _parse_row_request(
field["name"]: field for field in fields_info if not field.get("auto_id", False)
}

pk_field_name = ""
auto_id = False
for field in fields_info:
if field.get("auto_id", True):
pk_field_name = field["name"]
auto_id = True

if enable_dynamic:
d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
fields_data[d_field.field_name] = d_field
Expand All @@ -381,7 +388,11 @@ def _parse_row_request(
raise TypeError(msg)
for k, v in entity.items():
if k not in fields_data and not enable_dynamic:
raise DataNotMatchException(message=ExceptionsMessage.InsertUnexpectedField)
raise DataNotMatchException(
message=ExceptionsMessage.InsertUnexpectedField % k
)
if k == pk_field_name and auto_id:
raise DataNotMatchException(message=ExceptionsMessage.AutoIDWithData)

if k in fields_data:
field_info, field_data = field_info_map[k], fields_data[k]
Expand Down
Loading

0 comments on commit b327323

Please sign in to comment.