Remove redundant type inference when inserting data
Signed-off-by: zhenshan.cao <[email protected]>
czs007 committed Oct 18, 2023
1 parent 9ad3541 commit 8dd7c24
Showing 5 changed files with 138 additions and 113 deletions.
79 changes: 79 additions & 0 deletions examples/hello_milvus_array.py
@@ -0,0 +1,79 @@
from pymilvus import CollectionSchema, FieldSchema, Collection, connections, DataType, utility
import numpy as np
import random
import pandas as pd

connections.connect()

dim = 128
collection_name = "test_array"
arr_len = 100
nb = 10
if utility.has_collection(collection_name):
    utility.drop_collection(collection_name)
# create collection
pk_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, description='pk')
vector_field = FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=dim)
int8_array = FieldSchema(name="int8_array", dtype=DataType.ARRAY, element_type=DataType.INT8, max_capacity=arr_len)
int16_array = FieldSchema(name="int16_array", dtype=DataType.ARRAY, element_type=DataType.INT16, max_capacity=arr_len)
int32_array = FieldSchema(name="int32_array", dtype=DataType.ARRAY, element_type=DataType.INT32, max_capacity=arr_len)
int64_array = FieldSchema(name="int64_array", dtype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=arr_len)
bool_array = FieldSchema(name="bool_array", dtype=DataType.ARRAY, element_type=DataType.BOOL, max_capacity=arr_len)
float_array = FieldSchema(name="float_array", dtype=DataType.ARRAY, element_type=DataType.FLOAT, max_capacity=arr_len)
double_array = FieldSchema(name="double_array", dtype=DataType.ARRAY, element_type=DataType.DOUBLE, max_capacity=arr_len)
string_array = FieldSchema(name="string_array", dtype=DataType.ARRAY, element_type=DataType.VARCHAR,
                           max_capacity=arr_len, max_length=100)

fields = [pk_field, vector_field, int8_array, int16_array, int32_array, int64_array,
          bool_array, float_array, double_array, string_array]

schema = CollectionSchema(fields=fields)
collection = Collection(collection_name, schema=schema)
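# Note (added for clarity): ARRAY fields declare their element type via element_type
# and a per-row capacity via max_capacity; VARCHAR elements additionally require
# max_length, as in the string_array field above.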

# insert data
pk_value = [i for i in range(nb)]
vector_value = [[random.random() for _ in range(dim)] for _ in range(nb)]
int8_value = [[np.int8(j) for j in range(arr_len)] for _ in range(nb)]
int16_value = [[np.int16(j) for j in range(arr_len)] for _ in range(nb)]
int32_value = [[np.int32(j) for j in range(arr_len)] for _ in range(nb)]
int64_value = [[np.int64(j) for j in range(arr_len)] for _ in range(nb)]
bool_value = [[np.bool_(j) for j in range(arr_len)] for _ in range(nb)]
float_value = [[np.float32(j) for j in range(arr_len)] for _ in range(nb)]
double_value = [[np.double(j) for j in range(arr_len)] for _ in range(nb)]
string_value = [[str(j) for j in range(arr_len)] for _ in range(nb)]

data = [pk_value, vector_value,
        int8_value, int16_value, int32_value, int64_value,
        bool_value,
        float_value,
        double_value,
        string_value]

# Column-based insert (a list of columns in schema order) is an alternative
# to the DataFrame insert below:
# collection.insert(data)

data = pd.DataFrame({
    "int64": pk_value,
    "float_vector": vector_value,
    "int8_array": int8_value,
    "int16_array": int16_value,
    "int32_array": int32_value,
    "int64_array": int64_value,
    "bool_array": bool_value,
    "float_array": float_value,
    "double_array": double_value,
    "string_array": string_value,
})
collection.insert(data)

index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}

collection.create_index("float_vector", index)
collection.load()

res = collection.query("int64 >= 0", output_fields=["int8_array"])
for hits in res:
    print(hits)
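Not part of the commit: if you extend the script above, a vector search against the IVF_FLAT index built on float_vector could look like the sketch below. It reuses collection, dim, and random from the example; the search parameters are illustrative only.

search_params = {"metric_type": "L2", "params": {"nprobe": 16}}
query_vectors = [[random.random() for _ in range(dim)] for _ in range(2)]
results = collection.search(
    data=query_vectors,
    anns_field="float_vector",
    param=search_params,
    limit=3,
    output_fields=["int64_array"],
)
for hits in results:
    for hit in hits:
        print(hit.id, hit.distance)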
10 changes: 5 additions & 5 deletions pymilvus/client/entity_helper.py
@@ -70,8 +70,8 @@ def entity_to_json_arr(entity: Dict):
    return convert_to_json_arr(entity.get("values", []))


def convert_to_array_arr(objs: List[Any]):
    return [convert_to_array_arr(obj) for obj in objs]
def convert_to_array_arr(objs: List[Any], field_info: Any):
    return [convert_to_array(obj, field_info) for obj in objs]


def convert_to_array(obj: List[Any], field_info: Any):
@@ -100,8 +100,8 @@ def convert_to_array(obj: List[Any], field_info: Any):
    )


def entity_to_array_arr(entity: List[Any]):
    return convert_to_array_arr(entity.get("values", []))
def entity_to_array_arr(entity: List[Any], field_info: Any):
    return convert_to_array_arr(entity.get("values", []), field_info)


def pack_field_value_to_field_data(field_value: Any, field_data: Any, field_info: Any):
@@ -166,7 +166,7 @@ def entity_to_field_data(entity: Any, field_info: Any):
    elif entity_type == DataType.JSON:
        field_data.scalars.json_data.data.extend(entity_to_json_arr(entity))
    elif entity_type == DataType.ARRAY:
        field_data.scalars.array_data.data.extend(entity_to_array_arr(entity))
        field_data.scalars.array_data.data.extend(entity_to_array_arr(entity, field_info))
    else:
        raise ParamError(message=f"UnSupported data type: {entity_type}")

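For context, not part of the diff: a minimal sketch of how the field-aware helper might be called for one ARRAY column. Only the "values" key of the entity dict is confirmed by the code above; the shape of field_info (a dict carrying the element type) is an assumption for this sketch.

from pymilvus import DataType
from pymilvus.client.entity_helper import entity_to_array_arr

entity = {"name": "int8_array", "type": DataType.ARRAY, "values": [[1, 2, 3], [4, 5]]}
field_info = {"name": "int8_array", "element_type": DataType.INT8}  # assumed shape
array_column = entity_to_array_arr(entity, field_info)  # one converted entry per row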
1 change: 0 additions & 1 deletion pymilvus/client/grpc_handler.py
@@ -570,7 +570,6 @@ def batch_insert(
                m = MutationResult(response)
                ts_utils.update_collection_ts(collection_name, m.timestamp)
                return m

            raise MilvusException(
                response.status.code, response.status.reason, response.status.error_code
            )
10 changes: 5 additions & 5 deletions pymilvus/milvus_client/milvus_client.py
@@ -493,19 +493,19 @@ def create_collection_with_schema(
        self,
        collection_name: str,
        schema: CollectionSchema,
        index_param: Dict,
        index_params: Dict,
        timeout: Optional[float] = None,
        **kwargs,
    ):
        schema.verify()
        if kwargs.get("auto_id", True):
        if kwargs.get("auto_id", False):
            schema.auto_id = True
        if kwargs.get("enable_dynamic_field", False):
            schema.enable_dynamic_field = True
        schema.verify()

        index_param = index_param or {}
        vector_field_name = index_param.pop("field_name", "")
        index_params = index_params or {}
        vector_field_name = index_params.pop("field_name", "")
        if not vector_field_name:
            schema_dict = schema.to_dict()
            vector_field_name = self._get_vector_field_name(schema_dict)
@@ -520,7 +520,7 @@ def create_collection_with_schema(
            logger.error("Failed to create collection: %s", collection_name)
            raise ex from ex

        self._create_index(collection_name, vector_field_name, index_param, timeout=timeout)
        self._create_index(collection_name, vector_field_name, index_params, timeout=timeout)
        self._load(collection_name, timeout=timeout)

    def close(self):
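For orientation, not part of the diff: after this change the keyword argument is index_params and auto_id is no longer enabled by default. A minimal usage sketch; the collection name, field names, and index settings are illustrative, and a local Milvus at the default port is assumed.

from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType

client = MilvusClient(uri="http://localhost:19530")  # assumed local deployment

schema = CollectionSchema(fields=[
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="vec", dtype=DataType.FLOAT_VECTOR, dim=128),
])

# "field_name" is popped to select the vector field; the remaining keys are
# passed through to index creation.
index_params = {
    "field_name": "vec",
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
client.create_collection_with_schema("demo_collection", schema, index_params)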
151 changes: 49 additions & 102 deletions pymilvus/orm/schema.py
@@ -98,9 +98,10 @@ def _check_fields(self):
        primary_field_name = self._kwargs.get("primary_field", None)
        partition_key_field_name = self._kwargs.get("partition_key_field", None)
        for field in self._fields:
            if primary_field_name == field.name:
            if primary_field_name and primary_field_name == field.name:
                field.is_primary = True
            if partition_key_field_name == field.name:

            if partition_key_field_name and partition_key_field_name == field.name:
                field.is_partition_key = True

            if field.is_primary:
@@ -403,17 +404,58 @@ def check_is_row_based(data: Union[List[List], List[Dict], Dict, pd.DataFrame])
    return False


def _check_insert_data(data: Union[List[List], pd.DataFrame]):
    if not isinstance(data, (pd.DataFrame, list)):
        raise DataTypeNotSupportException(
            message="The type of data should be list or pandas.DataFrame"
        )
    is_dataframe = isinstance(data, pd.DataFrame)
    for col in data:
        if not is_dataframe and not is_list_like(col):
            raise DataTypeNotSupportException(message="data should be a list of list")


def _check_data_schema_cnt(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
    tmp_fields = copy.deepcopy(schema.fields)
    for i, field in enumerate(tmp_fields):
        if field.is_primary and field.auto_id:
            tmp_fields.pop(i)

    field_cnt = len(tmp_fields)
    is_dataframe = isinstance(data, pd.DataFrame)
    data_cnt = len(data.columns) if is_dataframe else len(data)
    if field_cnt != data_cnt:
        message = (
            f"The data don't match with schema fields, expect {field_cnt} list, got {len(data)}"
        )
        if is_dataframe:
            i_name = [f.name for f in tmp_fields]
            t_name = list(data.columns)
            message = f"The fields don't match with schema fields, expected: {i_name}, got {t_name}"

        raise DataNotMatchException(message=message)

    if is_dataframe:
        for x, y in zip(list(data.columns), tmp_fields):
            if x != y.name:
                raise DataNotMatchException(
                    message=f"The name of field don't match, expected: {y.name}, got {x}"
                )


def check_insert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
    if schema is None:
        raise SchemaNotReadyException(message="Schema shouldn't be None")
    if schema.auto_id and isinstance(data, pd.DataFrame) and schema.primary_field.name in data:
        if not data[schema.primary_field.name].isnull().all():
            msg = f"Expect no data for auto_id primary field: {schema.primary_field.name}"
            raise DataNotMatchException(message=msg)
        data = data.drop(schema.primary_field.name, axis=1)
        columns = list(data.columns)
        columns.remove(schema.primary_field)
        data = data[[columns]]

    infer_fields, tmp_fields, is_data_frame = parse_fields_from_data(schema, data)
    check_infer_fields_valid(infer_fields, tmp_fields, is_data_frame)
    _check_data_schema_cnt(schema, data)
    _check_insert_data(data)


def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
@@ -422,78 +464,8 @@ def check_upsert_schema(schema: CollectionSchema, data: Union[List[List], pd.Dat
    if schema.auto_id:
        raise UpsertAutoIDTrueException(message=ExceptionsMessage.UpsertAutoIDTrue)

    infer_fields, tmp_fields, is_data_frame = parse_fields_from_data(schema, data)
    check_infer_fields_valid(infer_fields, tmp_fields, is_data_frame)


def parse_fields_from_data(schema: CollectionSchema, data: Union[List[List], pd.DataFrame]):
    if not isinstance(data, (pd.DataFrame, list)):
        raise DataTypeNotSupportException(
            message="The type of data should be list or pandas.DataFrame"
        )

    if isinstance(data, pd.DataFrame):
        return parse_fields_from_dataframe(schema, data)

    tmp_fields = copy.deepcopy(schema.fields)
    for i, field in enumerate(tmp_fields):
        if field.is_primary and field.auto_id:
            tmp_fields.pop(i)

    infer_fields = []
    for i, field in enumerate(tmp_fields):
        try:
            d = data[i]
            if not is_list_like(d):
                raise DataTypeNotSupportException(message="data should be a list of list")
            try:
                elem = d[0]
                infer_fields.append(FieldSchema("", infer_dtype_bydata(elem)))
            # if pass in [] or None, considering to be passed in order according to the schema
            except IndexError:
                infer_fields.append(FieldSchema("", field.dtype))
        # the last missing part of data is also completed in order according to the schema
        except IndexError:
            infer_fields.append(FieldSchema("", field.dtype))

    index = len(tmp_fields)
    while index < len(data):
        fields = FieldSchema("", infer_dtype_bydata(data[index][0]))
        infer_fields.append(fields)
        index = index + 1

    return infer_fields, tmp_fields, False


def parse_fields_from_dataframe(schema: CollectionSchema, df: pd.DataFrame):
    col_names, data_types, column_params_map = prepare_fields_from_dataframe(df)
    tmp_fields = copy.deepcopy(schema.fields)
    for i, field in enumerate(schema.fields):
        if field.is_primary and field.auto_id:
            tmp_fields.pop(i)
    infer_fields = []
    for field in tmp_fields:
        # if no data pass in, considering to be passed in order according to the schema
        if field.name not in col_names:
            field_schema = FieldSchema(field.name, field.dtype)
            col_names.append(field.name)
            data_types.append(field.dtype)
            infer_fields.append(field_schema)
        else:
            type_params = column_params_map.get(field.name, {})
            field_schema = FieldSchema(
                field.name, data_types[col_names.index(field.name)], **type_params
            )
            infer_fields.append(field_schema)

    infer_name = [f.name for f in infer_fields]
    for name, dtype in zip(col_names, data_types):
        if name not in infer_name:
            type_params = column_params_map.get(name, {})
            field_schema = FieldSchema(name, dtype, **type_params)
            infer_fields.append(field_schema)

    return infer_fields, tmp_fields, True
    _check_data_schema_cnt(schema, data)
    _check_insert_data(data)


def construct_fields_from_dataframe(df: pd.DataFrame) -> List[FieldSchema]:
@@ -536,31 +508,6 @@ def prepare_fields_from_dataframe(df: pd.DataFrame):
    return col_names, data_types, column_params_map


def check_infer_fields_valid(
    infer_fields: List[FieldSchema],
    tmp_fields: List,
    is_data_frame: bool,
):
    if len(infer_fields) != len(tmp_fields):
        i_name = [f.name for f in infer_fields]
        t_name = [f.name for f in tmp_fields]
        raise DataNotMatchException(
            message=f"The fields don't match with schema fields, expected: {t_name}, got {i_name}"
        )

    for x, y in zip(infer_fields, tmp_fields):
        if is_data_frame and x.name != y.name:
            raise DataNotMatchException(
                message=f"The name of field don't match, expected: {y.name}, got {x.name}"
            )
        if x.dtype != y.dtype:
            msg = (
                f"The data type of field {y.name} doesn't match, "
                f"expected: {y.dtype.name}, got {x.dtype.name}"
            )
            raise DataNotMatchException(message=msg)


def check_schema(schema: CollectionSchema):
    if schema is None:
        raise SchemaNotReadyException(message=ExceptionsMessage.NoSchema)
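Illustration, not part of the commit: with the inference helpers removed, insert validation reduces to the count and shape checks added above. A minimal sketch, assuming the internal module path pymilvus.orm.schema stays importable, of the error raised when a column is missing:

from pymilvus import CollectionSchema, FieldSchema, DataType
from pymilvus.exceptions import DataNotMatchException
from pymilvus.orm.schema import check_insert_schema  # internal helper shown in this diff

schema = CollectionSchema(fields=[
    FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="vec", dtype=DataType.FLOAT_VECTOR, dim=4),
])

good = [[1, 2], [[0.1] * 4, [0.2] * 4]]   # one column per non-auto-id field
missing_column = [[1, 2]]                 # vector column omitted

check_insert_schema(schema, good)         # passes: column count matches the schema
try:
    check_insert_schema(schema, missing_column)  # rejected by _check_data_schema_cnt
except DataNotMatchException as err:
    print(err)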
