Fix concurrent bug of bulkwriter (#1690)
Signed-off-by: yhmo <[email protected]>
yhmo authored Sep 14, 2023
1 parent acea079 commit 6b2653a
Showing 4 changed files with 95 additions and 54 deletions.
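
In brief, the concurrency bug: LocalBulkWriter.append_row may be called from many threads at once (the updated example below uses ten), but the buffer counters and the size-triggered flush were updated without synchronization, so concurrent appends could corrupt the counts or kick off overlapping flushes. The fix guards the counters with a buffer lock and serializes the flush decision with a working-thread lock. A minimal sketch of that pattern, using hypothetical names rather than the actual pymilvus classes:

import threading

class SketchWriter:
    # Illustrative only: mirrors the locking pattern this commit introduces
    # in bulk_writer.py and local_bulk_writer.py.
    def __init__(self, segment_size: int):
        self._segment_size = segment_size
        self._buffer_size = 0
        self._buffer_lock = threading.Lock()          # guards the counters
        self._working_thread_lock = threading.Lock()  # serializes the flush decision

    def append_row(self, row_size: int):
        with self._buffer_lock:                       # unguarded "+=" can lose updates
            self._buffer_size += row_size
        with self._working_thread_lock:               # one thread at a time decides to flush
            if self._buffer_size > self._segment_size:
                self._flush()

    def _flush(self):
        with self._buffer_lock:                       # reset under the same lock
            self._buffer_size = 0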
93 changes: 54 additions & 39 deletions examples/example_bulkwriter.py
@@ -145,18 +145,20 @@ def _append_row(writer: LocalBulkWriter, begin: int, end: int):
try:
for i in range(begin, end):
writer.append_row({"path": f"path_{i}", "vector": gen_float_vector(), "label": f"label_{i}"})
if i%100 == 0:
print(f"{threading.current_thread().name} inserted {i-begin} items")
except Exception as e:
print("failed to append row!")

local_writer = LocalBulkWriter(
schema=schema,
local_path="/tmp/bulk_writer",
segment_size=1000 * 1024 * 1024,
segment_size=128 * 1024 * 1024, # 128MB
file_type=BulkFileType.JSON_RB,
)
threads = []
thread_count = 100
rows_per_thread = 100
thread_count = 10
rows_per_thread = 1000
for k in range(thread_count):
x = threading.Thread(target=_append_row, args=(local_writer, k*rows_per_thread, (k+1)*rows_per_thread,))
threads.append(x)
@@ -169,17 +171,24 @@ def _append_row(writer: LocalBulkWriter, begin: int, end: int):

local_writer.commit()
print(f"Append finished, {thread_count*rows_per_thread} rows")
file_path = os.path.join(local_writer.data_path, "1.json")
with open(file_path, 'r') as file:
data = json.load(file)

print("Verify the output content...")
rows = data['rows']
assert len(rows) == thread_count*rows_per_thread
for row in rows:
pa = row['path']
label = row['label']
assert pa.replace("path_", "") == label.replace("label_", "")

row_count = 0
batch_files = local_writer.batch_files
for batch in batch_files:
for file_path in batch:
with open(file_path, 'r') as file:
data = json.load(file)

rows = data['rows']
row_count = row_count + len(rows)
print(f"The file {file_path} contains {len(rows)} rows. Verify the content...")

for row in rows:
pa = row['path']
label = row['label']
assert pa.replace("path_", "") == label.replace("label_", "")

assert row_count == thread_count * rows_per_thread
print("Data is correct")


@@ -196,16 +205,23 @@ def test_remote_writer(schema: CollectionSchema):
),
segment_size=50 * 1024 * 1024,
) as remote_writer:
# read data from csv
read_sample_data("./data/train_embeddings.csv", remote_writer)
remote_writer.commit()

# append rows
for i in range(10000):
remote_writer.append_row({"path": f"path_{i}", "vector": gen_float_vector(), "label": f"label_{i}"})
remote_writer.commit()

batch_files = remote_writer.batch_files

print(f"Test remote writer done! output remote files: {batch_files}")


def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list:
print(f"\n===================== all types test ====================")
remote_writer = RemoteBulkWriter(
with RemoteBulkWriter(
schema=schema,
remote_path="bulk_data",
connect_param=RemoteBulkWriter.ConnectParam(
@@ -214,30 +230,29 @@ def test_all_types_writer(bin_vec: bool, schema: CollectionSchema)->list:
secret_key=MINIO_SECRET_KEY,
bucket_name="a-bucket",
),
)

print("Append rows")
for i in range(10000):
row = {
"id": i,
"bool": True if i%5 == 0 else False,
"int8": i%128,
"int16": i%1000,
"int32": i%100000,
"int64": i,
"float": i/3,
"double": i/7,
"varchar": f"varchar_{i}",
"json": {"dummy": i, "ok": f"name_{i}"},
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
f"dynamic_{i}": i,
}
remote_writer.append_row(row)

print("Generate data files...")
remote_writer.commit()
print(f"Data files have been uploaded: {remote_writer.batch_files}")
return remote_writer.batch_files
) as remote_writer:
print("Append rows")
for i in range(10000):
row = {
"id": i,
"bool": True if i%5 == 0 else False,
"int8": i%128,
"int16": i%1000,
"int32": i%100000,
"int64": i,
"float": i/3,
"double": i/7,
"varchar": f"varchar_{i}",
"json": {"dummy": i, "ok": f"name_{i}"},
"vector": gen_binary_vector() if bin_vec else gen_float_vector(),
f"dynamic_{i}": i,
}
remote_writer.append_row(row)

print("Generate data files...")
remote_writer.commit()
print(f"Data files have been uploaded: {remote_writer.batch_files}")
return remote_writer.batch_files


def test_call_bulkinsert(schema: CollectionSchema, batch_files: list):
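Note the shape of the new verification loop: instead of assuming a single output file named 1.json, it walks local_writer.batch_files, since a commit can produce several batches and each batch can contain several files. A small helper in the same spirit (a sketch that assumes the example's local_writer and the JSON_RB layout with a top-level "rows" array):

import json

def count_rows(file_path: str) -> int:
    # JSON_RB batch files carry a top-level "rows" array.
    with open(file_path, "r") as f:
        return len(json.load(f)["rows"])

total = sum(count_rows(path)
            for batch in local_writer.batch_files
            for path in batch)
assert total == thread_count * rows_per_thread
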
10 changes: 6 additions & 4 deletions pymilvus/bulk_writer/bulk_writer.py
@@ -77,8 +77,9 @@ def append_row(self, row: dict, **kwargs):
self._buffer.append_row(row)

def commit(self, **kwargs):
self._buffer_size = 0
self._buffer_row_count = 0
with self._buffer_lock:
self._buffer_size = 0
self._buffer_row_count = 0

@property
def data_path(self):
@@ -143,5 +144,6 @@ def _verify_row(self, row: dict):

row_size = row_size + TYPE_SIZE[dtype.name]

self._buffer_size = self._buffer_size + row_size
self._buffer_row_count = self._buffer_row_count + 1
with self._buffer_lock:
self._buffer_size = self._buffer_size + row_size
self._buffer_row_count = self._buffer_row_count + 1
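
For context on why _verify_row and commit() now take self._buffer_lock: counter updates like self._buffer_size = self._buffer_size + row_size are a read-modify-write sequence, not an atomic step, so two threads can read the same old value and one update is lost. A standalone demonstration of the hazard (not pymilvus code; how often updates are actually lost varies with the interpreter version):

import threading

counter = 0

def bump(n: int) -> None:
    global counter
    for _ in range(n):
        counter += 1  # read, add, write back: interleavings can lose increments

threads = [threading.Thread(target=bump, args=(100_000,)) for _ in range(10)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter)  # can be less than 1,000,000 without a lock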
40 changes: 29 additions & 11 deletions pymilvus/bulk_writer/local_bulk_writer.py
@@ -15,7 +15,7 @@
import time
import uuid
from pathlib import Path
from threading import Thread
from threading import Lock, Thread
from typing import Callable, Optional

from pymilvus.orm.schema import CollectionSchema
@@ -43,7 +43,9 @@ def __init__(
self._uuid = str(uuid.uuid4())
self._flush_count = 0
self._working_thread = {}
self._working_thread_lock = Lock()
self._local_files = []
self._make_dir()

@property
def uuid(self):
@@ -59,10 +61,12 @@ def __del__(self):
self._exit()

def _exit(self):
# remove the uuid folder if it is empty
if Path(self._local_path).exists() and not any(Path(self._local_path).iterdir()):
Path(self._local_path).rmdir()
logger.info(f"Delete empty directory '{self._local_path}'")

# wait for the flush threads to finish
if len(self._working_thread) > 0:
for k, th in self._working_thread.items():
logger.info(f"Wait flush thread '{k}' to finish")
@@ -79,38 +83,52 @@ def _make_dir(self):
def append_row(self, row: dict, **kwargs):
super().append_row(row, **kwargs)

if super().buffer_size > super().segment_size:
self.commit(_async=True)
# only one thread at a time can enter this section to persist data;
# in the _flush() method, the buffer is swapped for a new one.
# in async mode, the flush thread runs asynchronously, so other threads
# can continue to append while the new buffer is below the target size
with self._working_thread_lock:
if super().buffer_size > super().segment_size:
self.commit(_async=True)

def commit(self, **kwargs):
# _async=True: the flush thread runs asynchronously
while len(self._working_thread) > 0:
logger.info("Previous flush action is not finished, waiting...")
time.sleep(0.5)
logger.info(
f"Previous flush action is not finished, {threading.current_thread().name} is waiting..."
)
time.sleep(1.0)

logger.info(
f"Prepare to flush buffer, row_count: {super().buffer_row_count}, size: {super().buffer_size}"
)
_async = kwargs.get("_async", False)
call_back = kwargs.get("call_back", None)

x = Thread(target=self._flush, args=(call_back,))
logger.info(f"Flush thread begin, name: {x.name}")
self._working_thread[x.name] = x
x.start()
if not _async:
logger.info("Wait flush to finish")
x.join()

super().commit() # reset the buffer size
logger.info(f"Commit done with async={_async}")

def _flush(self, call_back: Optional[Callable] = None):
self._make_dir()
self._working_thread[threading.current_thread().name] = threading.current_thread()
self._flush_count = self._flush_count + 1
target_path = Path.joinpath(self._local_path, str(self._flush_count))

old_buffer = super()._new_buffer()
file_list = old_buffer.persist(str(target_path))
self._local_files.append(file_list)
if call_back:
call_back(file_list)
if old_buffer.row_count > 0:
file_list = old_buffer.persist(str(target_path))
self._local_files.append(file_list)
if call_back:
call_back(file_list)

del self._working_thread[threading.current_thread().name]
logger.info(f"Flush thread done, name: {threading.current_thread().name}")

@property
def data_path(self):
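The commit/flush handshake above, reduced to a skeleton (a simplified sketch of the new control flow, not the full implementation): commit() waits until no earlier flush is running, registers the flush thread before starting it so _exit() can always find and join it, and _flush() skips persisting when the swapped-out buffer is empty.

import threading
import time

class FlushSketch:
    # Skeleton of the new commit/flush control flow (simplified; the real
    # class also manages buffers, target paths, and callbacks).
    def __init__(self):
        self._working_thread = {}
        self._rows = []

    def commit(self, _async: bool = False) -> None:
        while len(self._working_thread) > 0:   # let the previous flush finish first
            time.sleep(1.0)
        t = threading.Thread(target=self._flush)
        self._working_thread[t.name] = t       # register before start(): always joinable
        t.start()
        if not _async:
            t.join()                           # synchronous commit waits here

    def _flush(self) -> None:
        old_rows, self._rows = self._rows, []  # swap in a fresh buffer
        if len(old_rows) > 0:                  # skip persisting an empty buffer
            print(f"persisting {len(old_rows)} rows")
        del self._working_thread[threading.current_thread().name]
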
6 changes: 6 additions & 0 deletions pymilvus/bulk_writer/remote_bulk_writer.py
@@ -77,6 +77,12 @@ def __enter__(self):

def __exit__(self, exc_type: object, exc_val: object, exc_tb: object):
super().__exit__(exc_type, exc_val, exc_tb)
# remove the temp folder "bulk_writer" if it is empty
if Path(self._local_path).parent.exists() and not any(
Path(self._local_path).parent.iterdir()
):
Path(self._local_path).parent.rmdir()
logger.info(f"Delete empty directory '{Path(self._local_path).parent}'")

def _get_client(self):
try:
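The remote writer's cleanup follows the same cautious rule as the local writer's _exit(): delete a directory only when it exists and is empty, so user data is never removed. A minimal standalone version of that check:

from pathlib import Path

def remove_if_empty(folder: Path) -> None:
    # Delete only an existing, empty directory; leave anything else alone.
    if folder.exists() and not any(folder.iterdir()):
        folder.rmdir()

remove_if_empty(Path("/tmp/bulk_writer"))  # path from the example; adjust as needed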
