Skip to content

Commit

Permalink
Fix issue in data loading
Browse files Browse the repository at this point in the history
scale information was not handled properly.
  • Loading branch information
ksugar committed Mar 27, 2024
1 parent 9df6134 commit aea2f16
Showing 1 changed file with 30 additions and 35 deletions.
65 changes: 30 additions & 35 deletions elephant-core/elephant/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def profile(func):
return func


def _load_image(za_input, timepoint, use_median=False, img_size=None):
img = za_input[timepoint].astype('float32')
def _load_image(img, use_median=False, img_size=None):
if use_median and img.ndim == 3:
global_median = np.median(img)
for z in range(img.shape[0]):
Expand Down Expand Up @@ -86,38 +85,37 @@ def _get_memmap_or_load(za, timepoint, memmap_dir=None, use_median=False,
fpath = Path(memmap_dir) / f'{key}.dat'
lock = FileLock(str(fpath) + '.lock')
with lock:
if not fpath.exists():
logger().info(f'creating {fpath}')
fpath.parent.mkdir(parents=True, exist_ok=True)
if not fpath_org.exists():
logger().info(f'creating {fpath_org}')
fpath_org.parent.mkdir(parents=True, exist_ok=True)
img_org = np.memmap(
fpath_org,
dtype='float32',
mode='w+',
shape=za.shape[1:]
)
img_org[:] = za[timepoint].astype('float32')
if img_size is None:
img = img_org
else:
img = np.memmap(
fpath,
dtype='float32',
mode='w+',
shape=img_size
)
img[:] = F.interpolate(
torch.from_numpy(img_org)[None, None],
size=img_size,
mode='trilinear' if img.ndim == 3 else 'bilinear',
align_corners=True,
)[0, 0].numpy()
if use_median and img.ndim == 3:
global_median = np.median(img)
for z in range(img.shape[0]):
slice_median = np.median(img[z])
if 0 < slice_median:
img[z] -= slice_median - global_median
img = normalize_zero_one(img)
else:
img_org = np.memmap(
fpath_org,
dtype='float32',
mode='c',
shape=za.shape[1:]
)
if not fpath.exists():
logger().info(f'creating {fpath}')
fpath.parent.mkdir(parents=True, exist_ok=True)
img = np.memmap(
fpath,
dtype='float32',
mode='w+',
shape=img_size
)
img[:] = _load_image(
img_org,
use_median=use_median,
img_size=img_size,
)
logger().info(f'loading from {fpath}')
return np.memmap(
fpath,
Expand All @@ -126,14 +124,11 @@ def _get_memmap_or_load(za, timepoint, memmap_dir=None, use_median=False,
shape=za.shape[1:] if img_size is None else img_size
)
else:
img = za[timepoint].astype('float32')
if use_median and img.ndim == 3:
global_median = np.median(img)
for z in range(img.shape[0]):
slice_median = np.median(img[z])
if 0 < slice_median:
img[z] -= slice_median - global_median
img = normalize_zero_one(img)
img = _load_image(
za[timepoint].astype('float32'),
use_median=use_median,
img_size=img_size,
)
return img


Expand Down

0 comments on commit aea2f16

Please sign in to comment.