moved buffered_slice_writer (#135)
* moved buffered_slice_writer

* reformatted files

* fixed test

* reformatted files
MartinBuessemeyer authored Sep 27, 2019
1 parent bc85e70 commit 2d1668f
Showing 3 changed files with 93 additions and 36 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,4 +7,5 @@ wkw==0.0.8
 requests
 black
 cluster_tools==1.41
-natsort
+natsort
+psutil
4 changes: 1 addition & 3 deletions tests/test_utils.py
@@ -42,9 +42,7 @@ def test_buffered_slice_writer():
     mag = Mag(1)
     dataset_path = os.path.join(dataset_dir, layer_name, mag.to_layer_name())
 
-    with BufferedSliceWriter(
-        dataset_dir, layer_name, dtype, bbox, origin, mag=mag
-    ) as writer:
+    with BufferedSliceWriter(dataset_dir, layer_name, dtype, origin, mag=mag) as writer:
         for i in range(13):
             writer.write_slice(i, test_img)
     with wkw.Dataset.open(dataset_path, wkw.Header(dtype)) as data:
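
Editor's note: for orientation, a minimal sketch of how the writer is called after this change, mirroring the updated test. The bounding_box argument is gone from the constructor; the dataset path, layer name, and dummy image below are placeholders for this sketch, not part of the commit.

    import numpy as np
    from wkcuber.utils import BufferedSliceWriter
    from wkcuber.mag import Mag

    test_img = np.zeros((24, 24), dtype=np.uint8)  # one placeholder 2D slice
    # path and layer name are made up for illustration
    with BufferedSliceWriter("testoutput/ds", "color", np.uint8, (0, 0, 0), mag=Mag(1)) as writer:
        for z in range(13):
            writer.write_slice(z, test_img)  # buffered; flushed every buffer_size slices and on close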
122 changes: 90 additions & 32 deletions wkcuber/utils.py
@@ -5,17 +5,21 @@
 import argparse
 import cluster_tools
 import json
+import os
+import psutil
+from typing import List, Tuple, Union
 from glob import iglob
 from collections import namedtuple
 from multiprocessing import cpu_count, Lock
 import concurrent
 from concurrent.futures import ProcessPoolExecutor
 from os import path, getpid
 from platform import python_version
 from math import floor, ceil
-from .mag import Mag
+from logging import getLogger
+import traceback
 
-from .knossos import KnossosDataset, CUBE_EDGE_LEN
+from .knossos import KnossosDataset
+from .mag import Mag
 
 WkwDatasetInfo = namedtuple(
     "WkwDatasetInfo", ("dataset_path", "layer_name", "dtype", "mag")
@@ -26,6 +30,8 @@
 
 BLOCK_LEN = 32
 
+logger = getLogger(__name__)
+
 
 def open_wkw(info, **kwargs):
     if hasattr(info, "dtype"):
@@ -210,13 +216,16 @@ def wait_and_ensure_success(futures):
 class BufferedSliceWriter(object):
     def __init__(
         self,
-        dataset_path,
-        layer_name,
+        dataset_path: str,
+        layer_name: str,
         dtype,
-        bounding_box,
-        origin,
-        buffer_size=32,
-        mag=Mag(1),
+        origin: Union[Tuple[int, int, int], List[int]],
+        # buffer_size specifies how many slices are aggregated before they are flushed.
+        buffer_size: int = 32,
+        # file_len specifies how many buckets are written per dimension of a wkw cube.
+        # Using 32 results in 1 GB per wkw file for 8-bit data.
+        file_len: int = 32,
+        mag: Mag = Mag("1"),
     ):
 
         self.dataset_path = dataset_path
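
Editor's note: the 1 GB figure in the new file_len comment checks out if one assumes wkw's usual 32-voxel bucket edge (an assumption of this note, not stated in the diff):

    # 32 buckets per dimension x 32 voxels per bucket edge (assumed) = 1024 voxels per cube edge
    voxels_per_edge = 32 * 32
    file_size_bytes = voxels_per_edge ** 3  # one byte per voxel for 8-bit data
    print(file_size_bytes / 1024 ** 3)      # 1.0, i.e. one GiB per .wkw file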
@@ -225,9 +234,11 @@ def __init__(
 
         layer_path = path.join(self.dataset_path, self.layer_name, mag.to_layer_name())
 
-        self.dataset = wkw.Dataset.open(layer_path, wkw.Header(dtype))
+        self.dtype = dtype
+        self.dataset = wkw.Dataset.open(
+            layer_path, wkw.Header(dtype, file_len=file_len)
+        )
         self.origin = origin
-        self.bounding_box = bounding_box
 
         self.buffer = []
         self.current_z = None
@@ -255,33 +266,66 @@ def _write_buffer(self):
         if len(self.buffer) == 0:
             return
 
-        assert len(self.buffer) <= self.buffer_size
+        assert (
+            len(self.buffer) <= self.buffer_size
+        ), "The WKW buffer is larger than the defined batch_size. The buffer should have been flushed earlier. This is probably a bug in the BufferedSliceWriter."
+
+        uniq_dtypes = set(map(lambda _slice: _slice.dtype, self.buffer))
+        assert (
+            len(uniq_dtypes) == 1
+        ), "The buffer of BufferedSliceWriter contains slices with differing dtype."
+        assert uniq_dtypes.pop() == self.dtype, (
+            "The buffer of BufferedSliceWriter contains slices with a dtype "
+            "which differs from the dtype with which the BufferedSliceWriter was instantiated."
+        )
 
-        logging.debug(
+        logger.debug(
             "({}) Writing {} slices at position {}.".format(
                 getpid(), len(self.buffer), self.buffer_start_z
             )
         )
+        log_memory_consumption()
 
-        origin_with_offset = self.origin.copy()
-        origin_with_offset[2] = self.buffer_start_z
-        x_max = max(slice.shape[0] for slice in self.buffer)
-        y_max = max(slice.shape[1] for slice in self.buffer)
-        self.buffer = [
-            np.pad(
-                slice,
-                mode="constant",
-                pad_width=[(0, x_max - slice.shape[0]), (0, y_max - slice.shape[1])],
-            )
-            for slice in self.buffer
-        ]
-        data = np.concatenate(
-            [np.expand_dims(slice, 2) for slice in self.buffer], axis=2
-        )
-
-        self.dataset.write(origin_with_offset, data)
-
-        self.buffer = []
+        try:
+            origin_with_offset = list(self.origin)
+            origin_with_offset[2] = self.buffer_start_z
+            x_max = max(slice.shape[0] for slice in self.buffer)
+            y_max = max(slice.shape[1] for slice in self.buffer)
+
+            self.buffer = [
+                np.pad(
+                    slice,
+                    mode="constant",
+                    pad_width=[
+                        (0, x_max - slice.shape[0]),
+                        (0, y_max - slice.shape[1]),
+                    ],
+                )
+                for slice in self.buffer
+            ]
+
+            data = np.concatenate(
+                [np.expand_dims(slice, 2) for slice in self.buffer], axis=2
+            )
+
+            self.dataset.write(origin_with_offset, data)
+
+        except Exception as exc:
+            logger.error(
+                "({}) An exception occurred in BufferedSliceWriter._write_buffer with {} "
+                "slices at position {}. Original error is:\n{}:{}\n\nTraceback:".format(
+                    getpid(),
+                    len(self.buffer),
+                    self.buffer_start_z,
+                    type(exc).__name__,
+                    exc,
+                )
+            )
+            traceback.print_tb(exc.__traceback__)
+            logger.error("\n")
+
+            raise exc
+        finally:
+            self.buffer = []
 
     def close(self):
 
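Editor's note: the padding step above is easiest to see on a toy example (shapes invented for this sketch). Slices of unequal size are zero-padded at their high ends to a shared (x_max, y_max) footprint, then stacked along a new third axis so the whole buffer can be written in a single wkw call.

    import numpy as np

    slices = [np.ones((4, 3), dtype=np.uint8), np.ones((2, 5), dtype=np.uint8)]
    x_max = max(s.shape[0] for s in slices)  # 4
    y_max = max(s.shape[1] for s in slices)  # 5
    padded = [
        np.pad(
            s,
            mode="constant",  # pads with zeros by default
            pad_width=[(0, x_max - s.shape[0]), (0, y_max - s.shape[1])],
        )
        for s in slices
    ]
    data = np.concatenate([np.expand_dims(s, 2) for s in padded], axis=2)
    print(data.shape)  # (4, 5, 2) -> x, y, z
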
@@ -291,5 +335,19 @@ def close(self):
     def __enter__(self):
         return self
 
-    def __exit__(self, type, value, tb):
+    def __exit__(self, _type, _value, _tb):
         self.close()
+
+
+def log_memory_consumption(additional_output=""):
+    pid = os.getpid()
+    process = psutil.Process(pid)
+    logging.info(
+        "Currently consuming {:.2f} GB of memory ({:.2f} GB still available) "
+        "in process {}. {}".format(
+            process.memory_info().rss / 1024 ** 3,
+            psutil.virtual_memory().available / 1024 ** 3,
+            pid,
+            additional_output,
+        )
+    )
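
Editor's note: the new helper is plain psutil (which explains the requirements.txt addition): memory_info().rss is the resident set size of the current process, and virtual_memory().available is the system-wide headroom. A standalone check of the same quantities:

    import psutil

    proc = psutil.Process()  # no pid argument -> current process
    rss_gb = proc.memory_info().rss / 1024 ** 3
    avail_gb = psutil.virtual_memory().available / 1024 ** 3
    print("{:.2f} GB used, {:.2f} GB available".format(rss_gb, avail_gb))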
