Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support numpy type value for bulkwriter #1718

Merged
merged 1 commit into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion examples/example_bulkwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list:
),
) as remote_writer:
print("Append rows")
for i in range(10000):
batch_count = 10000
for i in range(batch_count):
row = {
"id": i,
"bool": True if i%5 == 0 else False,
Expand All @@ -249,6 +250,23 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list:
}
remote_writer.append_row(row)

# append rows by numpy type
for i in range(batch_count):
remote_writer.append_row({
"id": np.int64(i+batch_count),
"bool": True if i % 3 == 0 else False,
"int8": np.int8(i%128),
"int16": np.int16(i%1000),
"int32": np.int32(i%100000),
"int64": np.int64(i),
"float": np.float32(i/3),
"double": np.float64(i/7),
"varchar": f"varchar_{i}",
"json": json.dumps({"dummy": i, "ok": f"name_{i}"}),
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
f"dynamic_{i}": i,
})

print("Generate data files...")
remote_writer.commit()
print(f"Data files have been uploaded: {remote_writer.batch_files}")
Expand Down
26 changes: 22 additions & 4 deletions pymilvus/bulk_writer/bulk_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import json
import logging
from threading import Lock

import numpy as np

from pymilvus.client.types import DataType
from pymilvus.exceptions import MilvusException
from pymilvus.orm.schema import CollectionSchema
Expand Down Expand Up @@ -85,6 +88,16 @@ def commit(self, **kwargs):
def data_path(self):
return ""

def _try_convert_json(self, field_name: str, obj: object):
if isinstance(obj, str):
try:
return json.loads(obj)
except Exception as e:
self._throw(
f"Illegal JSON value for field '{field_name}', type mismatch or illegal format, error: {e}"
)
return obj

def _throw(self, msg: str):
    """Log *msg* at error level, then raise it as a ``MilvusException``.

    Args:
        msg: Human-readable description of the failure; used both for the
            log record and as the exception message.

    Raises:
        MilvusException: Always, carrying *msg* as its message.
    """
    # Log first so the message is recorded even if the caller swallows the exception.
    logger.error(msg)
    raise MilvusException(message=msg)
Expand All @@ -109,10 +122,12 @@ def _verify_row(self, row: dict):
dtype = DataType(field.dtype)
validator = TYPE_VALIDATOR[dtype.name]
if dtype in {DataType.BINARY_VECTOR, DataType.FLOAT_VECTOR}:
if isinstance(row[field.name], np.ndarray):
row[field.name] = row[field.name].tolist()
dim = field.params["dim"]
if not validator(row[field.name], dim):
self._throw(
f"Illegal vector data for vector field: '{dtype.name}',"
f"Illegal vector data for vector field: '{field.name}',"
f" dim is not {dim} or type mismatch"
)

Expand All @@ -126,20 +141,23 @@ def _verify_row(self, row: dict):
max_len = field.params["max_length"]
if not validator(row[field.name], max_len):
self._throw(
f"Illegal varchar value for field '{dtype.name}',"
f"Illegal varchar value for field '{field.name}',"
f" length exceeds {max_len} or type mismatch"
)

row_size = row_size + len(row[field.name])
elif dtype == DataType.JSON:
row[field.name] = self._try_convert_json(field.name, row[field.name])
if not validator(row[field.name]):
self._throw(f"Illegal varchar value for field '{dtype.name}', type mismatch")
self._throw(f"Illegal JSON value for field '{field.name}', type mismatch")

row_size = row_size + len(row[field.name])
else:
if isinstance(row[field.name], np.generic):
row[field.name] = row[field.name].item()
if not validator(row[field.name]):
self._throw(
f"Illegal scalar value for field '{dtype.name}', value overflow or type mismatch"
f"Illegal scalar value for field '{field.name}', value overflow or type mismatch"
)

row_size = row_size + TYPE_SIZE[dtype.name]
Expand Down
Loading