From aea2f16ea26b214c37a50fca99c6744d405e449c Mon Sep 17 00:00:00 2001 From: Ko Sugawara Date: Wed, 27 Mar 2024 22:37:43 +0900 Subject: [PATCH] Fix issue in data loading scale information was not handled properly. --- elephant-core/elephant/datasets.py | 65 ++++++++++++++---------------- 1 file changed, 30 insertions(+), 35 deletions(-) diff --git a/elephant-core/elephant/datasets.py b/elephant-core/elephant/datasets.py index 40b90fa..e1a30ef 100644 --- a/elephant-core/elephant/datasets.py +++ b/elephant-core/elephant/datasets.py @@ -57,8 +57,7 @@ def profile(func): return func -def _load_image(za_input, timepoint, use_median=False, img_size=None): - img = za_input[timepoint].astype('float32') +def _load_image(img, use_median=False, img_size=None): if use_median and img.ndim == 3: global_median = np.median(img) for z in range(img.shape[0]): @@ -86,9 +85,9 @@ def _get_memmap_or_load(za, timepoint, memmap_dir=None, use_median=False, fpath = Path(memmap_dir) / f'{key}.dat' lock = FileLock(str(fpath) + '.lock') with lock: - if not fpath.exists(): - logger().info(f'creating {fpath}') - fpath.parent.mkdir(parents=True, exist_ok=True) + if not fpath_org.exists(): + logger().info(f'creating {fpath_org}') + fpath_org.parent.mkdir(parents=True, exist_ok=True) img_org = np.memmap( fpath_org, dtype='float32', @@ -96,28 +95,27 @@ def _get_memmap_or_load(za, timepoint, memmap_dir=None, use_median=False, shape=za.shape[1:] ) img_org[:] = za[timepoint].astype('float32') - if img_size is None: - img = img_org - else: - img = np.memmap( - fpath, - dtype='float32', - mode='w+', - shape=img_size - ) - img[:] = F.interpolate( - torch.from_numpy(img_org)[None, None], - size=img_size, - mode='trilinear' if img.ndim == 3 else 'bilinear', - align_corners=True, - )[0, 0].numpy() - if use_median and img.ndim == 3: - global_median = np.median(img) - for z in range(img.shape[0]): - slice_median = np.median(img[z]) - if 0 < slice_median: - img[z] -= slice_median - global_median - img = normalize_zero_one(img) + else: + img_org = np.memmap( + fpath_org, + dtype='float32', + mode='c', + shape=za.shape[1:] + ) + if not fpath.exists(): + logger().info(f'creating {fpath}') + fpath.parent.mkdir(parents=True, exist_ok=True) + img = np.memmap( + fpath, + dtype='float32', + mode='w+', + shape=img_size + ) + img[:] = _load_image( + img_org, + use_median=use_median, + img_size=img_size, + ) logger().info(f'loading from {fpath}') return np.memmap( fpath, @@ -126,14 +124,11 @@ def _get_memmap_or_load(za, timepoint, memmap_dir=None, use_median=False, shape=za.shape[1:] if img_size is None else img_size ) else: - img = za[timepoint].astype('float32') - if use_median and img.ndim == 3: - global_median = np.median(img) - for z in range(img.shape[0]): - slice_median = np.median(img[z]) - if 0 < slice_median: - img[z] -= slice_median - global_median - img = normalize_zero_one(img) + img = _load_image( + za[timepoint].astype('float32'), + use_median=use_median, + img_size=img_size, + ) return img