Skip to content

Commit

Permalink
fix dataloader unit-test defect and nng perf test defect
Browse files Browse the repository at this point in the history
  • Loading branch information
SolenoidWGT committed Jan 18, 2023
1 parent 3fa5319 commit 8b1d69a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions ding/framework/message_queue/perfs/perf_nng.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def recv_loop():
continue
elif topic == "f":
finish_tag.append(1)
send_t("f")
mq.stop()
return
else:
raise RuntimeError("Unkown topic")
Expand Down
7 changes: 5 additions & 2 deletions ding/utils/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def __init__(
p, c = self.mp_context.Pipe()

# Async process (Main worker): Process data if num_workers <= 1; Assign job to other workers if num_workers > 1.
self.async_process = self.mp_context.Process(target=self._async_loop, args=(p, c))
self.async_process = self.mp_context.Process(target=self._async_loop, args=(p, c), name="async_process")
self.async_process.daemon = True
self.async_process.start()

# Get data thread: Get data from ``data_source`` and send it to ``async_process``.`
self.get_data_thread = threading.Thread(target=self._get_data, args=(p, c))
self.get_data_thread = threading.Thread(target=self._get_data, args=(p, c), name="get_data_thread")
self.get_data_thread.daemon = True
self.get_data_thread.start()

Expand Down Expand Up @@ -350,6 +350,9 @@ def close(self) -> None:
self.end_flag = True
self.async_process.terminate()
self.async_process.join()
if self.use_cuda:
self.cuda_thread.join()
self.get_data_thread.join()
if self.num_workers > 1:
for w in self.worker:
w.terminate()
Expand Down
2 changes: 1 addition & 1 deletion ding/utils/data/tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,4 @@ def entry(self, batch_size, num_workers, chunk_size, use_cuda):
assert total_data_time <= 7 * 0.008
dataloader.__del__()
time.sleep(0.5)
assert len(threading.enumerate()) <= 2, threading.enumerate()
assert len(threading.enumerate()) <= 3, threading.enumerate()

0 comments on commit 8b1d69a

Please sign in to comment.