Bulkinsert writer #1664

Merged (1 commit, Sep 5, 2023)
197 changes: 197 additions & 0 deletions examples/example_bulkwriter.py
@@ -0,0 +1,197 @@
# Copyright (C) 2019-2023 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
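# This example demonstrates the bulk-writer feature added in this PR: generating
# bulk-insert data files with LocalBulkWriter and RemoteBulkWriter (row-based JSON
# or column-based NumPy output), and driving cloud imports with bulk_import,
# get_import_progress and list_import_jobs.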

import os
import json
import random
import threading

import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example_bulkwriter")

from pymilvus import (
connections,
FieldSchema, CollectionSchema, DataType,
Collection,
utility,
LocalBulkWriter,
RemoteBulkWriter,
BulkFileType,
bulk_import,
get_import_progress,
list_import_jobs,
)

# minio
MINIO_ADDRESS = "0.0.0.0:9000"
MINIO_SECRET_KEY = "minioadmin"
MINIO_ACCESS_KEY = "minioadmin"

# milvus
HOST = '127.0.0.1'
PORT = '19530'

COLLECTION_NAME = "test_abc"
DIM = 256

def create_connection():
print(f"\nCreate connection...")
connections.connect(host=HOST, port=PORT)
print(f"\nConnected")


def build_collection():
if utility.has_collection(COLLECTION_NAME):
utility.drop_collection(COLLECTION_NAME)

field1 = FieldSchema(name="id", dtype=DataType.INT64, auto_id=True, is_primary=True)
field2 = FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=DIM)
field3 = FieldSchema(name="desc", dtype=DataType.VARCHAR, max_length=100)
schema = CollectionSchema(fields=[field1, field2, field3])
collection = Collection(name=COLLECTION_NAME, schema=schema)
print("Collection created")
return collection.schema

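# Append a few rows and commit them to a row-based JSON file via LocalBulkWriter.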
def test_local_writer_json(schema: CollectionSchema):
local_writer = LocalBulkWriter(schema=schema,
local_path="/tmp/bulk_data",
segment_size=4*1024*1024,
file_type=BulkFileType.JSON_RB,
)
for i in range(10):
local_writer.append_row({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"})

local_writer.commit()
print("test local writer done!")
print(local_writer.data_path)
return local_writer.data_path

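# Same flow with the default file type (column-based NumPy files).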
def test_local_writer_npy(schema: CollectionSchema):
local_writer = LocalBulkWriter(schema=schema, local_path="/tmp/bulk_data", segment_size=4*1024*1024)
for i in range(10000):
local_writer.append_row({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"})

local_writer.commit()
print("test local writer done!")
print(local_writer.data_path)
return local_writer.data_path


def _append_row(writer: LocalBulkWriter, begin: int, end: int):
for i in range(begin, end):
writer.append_row({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"})

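# Append rows to a single LocalBulkWriter from many threads concurrently, then verify the JSON output.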
def test_parallel_append(schema: CollectionSchema):
local_writer = LocalBulkWriter(schema=schema,
local_path="/tmp/bulk_data",
segment_size=1000 * 1024 * 1024,
file_type=BulkFileType.JSON_RB,
)
threads = []
thread_count = 100
rows_per_thread = 1000
for k in range(thread_count):
x = threading.Thread(target=_append_row, args=(local_writer, k*rows_per_thread, (k+1)*rows_per_thread,))
threads.append(x)
x.start()
print(f"Thread '{x.name}' started")

for th in threads:
th.join()

local_writer.commit()
print(f"Append finished, {thread_count*rows_per_thread} rows")
file_path = os.path.join(local_writer.data_path, "1.json")
with open(file_path, 'r') as file:
data = json.load(file)

print("Verify the output content...")
rows = data['rows']
assert len(rows) == thread_count*rows_per_thread
for i in range(len(rows)):
row = rows[i]
assert row['desc'] == f"description_{row['id']}"


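# Write rows through RemoteBulkWriter, which is configured with a MinIO connection, bucket and remote path.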
def test_remote_writer(schema: CollectionSchema):
remote_writer = RemoteBulkWriter(schema=schema,
remote_path="bulk_data",
local_path="/tmp/bulk_data",
connect_param=RemoteBulkWriter.ConnectParam(
endpoint=MINIO_ADDRESS,
access_key=MINIO_ACCESS_KEY,
secret_key=MINIO_SECRET_KEY,
bucket_name="a-bucket",
),
segment_size=50 * 1024 * 1024,
)

for i in range(10000):
if i % 1000 == 0:
logger.info(f"{i} rows has been append to remote writer")
remote_writer.append_row({"id": i, "vector": [random.random() for _ in range(DIM)], "desc": f"description_{i}"})

remote_writer.commit()
print("test remote writer done!")
print(remote_writer.data_path)
return remote_writer.data_path


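# Submit a bulk-import job against a cloud deployment, then query its progress and list recent jobs.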
def test_cloud_bulkinsert():
url = "https://_your_cloud_server_url_"
cluster_id = "_your_cloud_instance_id_"

print(f"===================== import files to cloud vectordb ====================")
object_url = "_your_object_storage_service_url_"
object_url_access_key = "_your_object_storage_service_access_key_"
object_url_secret_key = "_your_object_storage_service_secret_key_"
resp = bulk_import(
url=url,
object_url=object_url,
access_key=object_url_access_key,
secret_key=object_url_secret_key,
cluster_id=cluster_id,
collection_name=COLLECTION_NAME,
)
print(resp)

print(f"===================== get import job progress ====================")
job_id = resp['data']['jobId']
resp = get_import_progress(
url=url,
job_id=job_id,
cluster_id=cluster_id,
)
print(resp)

print(f"===================== list import jobs ====================")
resp = list_import_jobs(
url=url,
cluster_id=cluster_id,
page_size=10,
current_page=1,
)
print(resp)


if __name__ == '__main__':
create_connection()
schema = build_collection()

test_local_writer_json(schema)
test_local_writer_npy(schema)
test_remote_writer(schema)
test_parallel_append(schema)

# test_cloud_bulkinsert()

22 changes: 22 additions & 0 deletions pymilvus/__init__.py
@@ -10,6 +10,22 @@
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

from .bulk_writer.bulk_import import (
bulk_import,
get_import_progress,
list_import_jobs,
)

# bulk writer
from .bulk_writer.constants import (
BulkFileType,
)
from .bulk_writer.local_bulk_writer import (
LocalBulkWriter,
)
from .bulk_writer.remote_bulk_writer import (
RemoteBulkWriter,
)
from .client import __version__
from .client.prepare import Prepare
from .client.stub import Milvus
@@ -124,4 +140,10 @@
"ResourceGroupInfo",
"Connections",
"IndexType",
"BulkFileType",
"LocalBulkWriter",
"RemoteBulkWriter",
"bulk_import",
"get_import_progress",
"list_import_jobs",
]
Empty file.
148 changes: 148 additions & 0 deletions pymilvus/bulk_writer/buffer.py
@@ -0,0 +1,148 @@
# Copyright (C) 2019-2023 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import json
import logging
from pathlib import Path

import numpy as np

from pymilvus.exceptions import MilvusException
from pymilvus.orm.schema import CollectionSchema

from .constants import (
DYNAMIC_FIELD_NAME,
BulkFileType,
)

logger = logging.getLogger("bulk_buffer")
logger.setLevel(logging.DEBUG)


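# In-memory buffer that accumulates appended rows per field and persists them either as
# column-based NumPy files (one .npy per field) or as a single row-based JSON file.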
class Buffer:
def __init__(
self,
schema: CollectionSchema,
file_type: BulkFileType = BulkFileType.NPY,
):
self._buffer = {}
self._file_type = file_type
for field in schema.fields:
self._buffer[field.name] = []

if len(self._buffer) == 0:
self._throw("Illegal collection schema: fields list is empty")

# dynamic field, internal name is '$meta'
if schema.enable_dynamic_field:
self._buffer[DYNAMIC_FIELD_NAME] = []

@property
def row_count(self) -> int:
if len(self._buffer) == 0:
return 0

for k in self._buffer:
return len(self._buffer[k])
return None

def _throw(self, msg: str):
logger.error(msg)
raise MilvusException(message=msg)

def append_row(self, row: dict):
dynamic_values = {}
if DYNAMIC_FIELD_NAME in row and not isinstance(row[DYNAMIC_FIELD_NAME], dict):
self._throw(f"Dynamic field '{DYNAMIC_FIELD_NAME}' value should be JSON format")

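# Values for keys that are not schema fields are routed into the dynamic field's JSON payload.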
for k in row:
if k == DYNAMIC_FIELD_NAME:
dynamic_values.update(row[k])
continue

if k not in self._buffer:
dynamic_values[k] = row[k]
else:
self._buffer[k].append(row[k])

if DYNAMIC_FIELD_NAME in self._buffer:
self._buffer[DYNAMIC_FIELD_NAME].append(json.dumps(dynamic_values))

def persist(self, local_path: str) -> list:
# verify row count of fields are equal
row_count = -1
for k in self._buffer:
if row_count < 0:
row_count = len(self._buffer[k])
elif row_count != len(self._buffer[k]):
self._throw(
"Column `{}` row count {} doesn't equal to the first column row count {}".format(
k, len(self._buffer[k]), row_count
)
)

# output files
if self._file_type == BulkFileType.NPY:
return self._persist_npy(local_path)
if self._file_type == BulkFileType.JSON_RB:
return self._persist_json_rows(local_path)

self._throw(f"Unsupported file tpye: {self._file_type}")
return []

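# Save each buffered column as <field_name>.npy under local_path; roll back all files if any save fails.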
def _persist_npy(self, local_path: str):
Path(local_path).mkdir(exist_ok=True)

file_list = []
for k in self._buffer:
full_file_name = Path(local_path).joinpath(k + ".npy")
file_list.append(full_file_name)
try:
np.save(full_file_name, self._buffer[k])
except Exception as e:
self._throw(f"Failed to persist column-based file {full_file_name}, error: {e}")

logger.info(f"Successfully persist column-based file {full_file_name}")

if len(file_list) != len(self._buffer):
logger.error("Some of fields were not persisted successfully, abort the files")
for f in file_list:
Path(f).unlink()
Path(local_path).rmdir()
file_list.clear()
self._throw("Some of fields were not persisted successfully, abort the files")

return file_list

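# Rebuild rows from the buffered columns and dump them as {"rows": [...]} into <local_path>.json.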
def _persist_json_rows(self, local_path: str):
rows = []
row_count = len(next(iter(self._buffer.values())))
row_index = 0
while row_index < row_count:
row = {}
for k, v in self._buffer.items():
row[k] = v[row_index]
rows.append(row)
row_index = row_index + 1

data = {
"rows": rows,
}
file_path = Path(local_path + ".json")
try:
with file_path.open("w") as json_file:
json.dump(data, json_file, indent=2)
except Exception as e:
self._throw(f"Failed to persist row-based file {file_path}, error: {e}")

logger.info(f"Successfully persist row-based file {file_path}")
return [file_path]