Skip to content

Commit

Permalink
Support for anisotropic data.
Browse files Browse the repository at this point in the history
  • Loading branch information
almarklein committed Nov 4, 2020
1 parent 8d91b7b commit a8b7388
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 46 deletions.
101 changes: 66 additions & 35 deletions dash_3d_viewer/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dash.dependencies import Input, Output, State, ALL
from dash_core_components import Graph, Slider, Store

from .utils import img_array_to_uri, get_thumbnail_size_from_shape
from .utils import img_array_to_uri, get_thumbnail_size_from_shape, shape3d_to_size2d


class DashVolumeSlicer:
Expand All @@ -13,6 +13,11 @@ class DashVolumeSlicer:
Parameters:
app (dash.Dash): the Dash application instance.
volume (ndarray): the 3D numpy array to slice through.
The dimensions are assumed to be in zyx order.
spacing (tuple of floats): The voxel size for each dimension (zyx).
The spacing and origin are applied to make the slice drawn in
"scene space" rather than "voxel space".
origin (tuple of floats): The offset for each dimension (zyx).
axis (int): the dimension to slice in. Default 0.
scene_id (str): the scene that this slicer is part of. Slicers
that have the same scene-id show each-other's positions with
Expand All @@ -38,14 +43,21 @@ class DashVolumeSlicer:

_global_slicer_counter = 0

def __init__(self, app, volume, axis=0, scene_id=None):
def __init__(
self, app, volume, *, spacing=None, origin=None, axis=0, scene_id=None
):
# todo: also implement xyz dim order?
if not isinstance(app, Dash):
raise TypeError("Expect first arg to be a Dash app.")
self._app = app
# Check and store volume
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
raise TypeError("Expected volume to be a 3D numpy array")
self._volume = volume
spacing = (1, 1, 1) if spacing is None else spacing
spacing = float(spacing[0]), float(spacing[1]), float(spacing[2])
origin = (0, 0, 0) if origin is None else origin
origin = float(origin[0]), float(origin[1]), float(origin[2])
# Check and store axis
if not (isinstance(axis, int) and 0 <= axis <= 2):
raise ValueError("The given axis must be 0, 1, or 2.")
Expand All @@ -60,20 +72,26 @@ def __init__(self, app, volume, axis=0, scene_id=None):
DashVolumeSlicer._global_slicer_counter += 1
self.context_id = "slicer_" + str(DashVolumeSlicer._global_slicer_counter)

# Get the slice size (width, height), and max index
arr_shape = list(volume.shape)
arr_shape.pop(self._axis)
self._slice_size = tuple(reversed(arr_shape))
self._max_index = self._volume.shape[self._axis] - 1
# Prepare slice info
info = {
"shape": tuple(volume.shape),
"axis": self._axis,
"size": shape3d_to_size2d(volume.shape, axis),
"origin": shape3d_to_size2d(origin, axis),
"spacing": shape3d_to_size2d(spacing, axis),
}

# Prep low-res slices
thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32)
thumbnail_size = get_thumbnail_size_from_shape(
(info["size"][1], info["size"][0]), 32
)
thumbnails = [
img_array_to_uri(self._slice(i), thumbnail_size)
for i in range(self._max_index + 1)
for i in range(info["size"][2])
]
info["lowres_size"] = thumbnail_size

# Create a placeholder trace
# Create traces
# todo: can add "%{z[0]}", but that would be the scaled value ...
image_trace = Image(
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
Expand Down Expand Up @@ -106,22 +124,20 @@ def __init__(self, app, volume, axis=0, scene_id=None):
config={"scrollZoom": True},
)
# Create a slider object that the user can put in the layout (or not)
# todo: use tooltip to show current value?
self.slider = Slider(
id=self._subid("slider"),
min=0,
max=self._max_index,
max=info["size"][2] - 1,
step=1,
value=self._max_index // 2,
value=info["size"][2] // 2,
tooltip={"always_visible": False, "placement": "left"},
updatemode="drag",
)
# Create the stores that we need (these must be present in the layout)
self.stores = [
Store(
id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size
),
Store(id=self._subid("info"), data=info),
Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
Store(id=self._subid("position"), data=0),
Store(id=self._subid("_requested-slice-index"), data=0),
Store(id=self._subid("_slice-data"), data=""),
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
Expand Down Expand Up @@ -175,6 +191,17 @@ def _create_client_callbacks(self):
[Input(self._subid("slider"), "value")],
)

app.clientside_callback(
"""
function update_position(index, info) {
return info.origin[2] + index * info.spacing[2];
}
""",
Output(self._subid("position"), "data"),
[Input(self._subid("index"), "data")],
[State(self._subid("info"), "data")],
)

app.clientside_callback(
"""
function handle_slice_index(index) {
Expand Down Expand Up @@ -205,7 +232,7 @@ def _create_client_callbacks(self):

app.clientside_callback(
"""
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) {
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) {
let new_index = index_and_data[0];
let new_data = index_and_data[1];
// Store data in cache
Expand All @@ -214,18 +241,18 @@ def _create_client_callbacks(self):
slice_cache[new_index] = new_data;
// Get the data we need *now*
let data = slice_cache[index];
let x0 = 0, y0 = 0, dx = 1, dy = 1;
let x0 = info.origin[0], y0 = info.origin[1];
let dx = info.spacing[0], dy = info.spacing[1];
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
// Maybe we do not need an update
if (!data) {
data = lowres[index];
// Scale the image to take the exact same space as the full-res
// version. It's not correct, but it looks better ...
// slice_size = full_w, full_h, low_w, low_h
dx = slice_size[0] / slice_size[2];
dy = slice_size[1] / slice_size[3];
x0 = 0.5 * dx - 0.5;
y0 = 0.5 * dy - 0.5;
dx *= info.size[0] / info.lowres_size[0];
dy *= info.size[1] / info.lowres_size[1];
x0 += 0.5 * dx - 0.5 * info.spacing[0];
y0 += 0.5 * dy - 0.5 * info.spacing[1];
}
if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
return window.dash_clientside.no_update;
Expand Down Expand Up @@ -253,7 +280,7 @@ def _create_client_callbacks(self):
[
State(self._subid("graph"), "figure"),
State(self._subid("_slice-data-lowres"), "data"),
State(self._subid("_slice-size"), "data"),
State(self._subid("info"), "data"),
],
)

Expand All @@ -266,18 +293,22 @@ def _create_client_callbacks(self):
# * match any of the selected axii
app.clientside_callback(
"""
function handle_indicator(indices1, indices2, slice_size, current) {
let w = slice_size[0], h = slice_size[1];
let dx = w / 20, dy = h / 20;
function handle_indicator(positions1, positions2, info, current) {
let x0 = info.origin[0], y0 = info.origin[1];
let x1 = x0 + info.size[0] * info.spacing[0], y1 = y0 + info.size[1] * info.spacing[1];
x0 = x0 - info.spacing[0], y0 = y0 - info.spacing[1];
let d = ((x1 - x0) + (y1 - y0)) * 0.5 * 0.05;
let version = (current.version || 0) + 1;
let x = [], y = [];
for (let index of indices1) {
x.push(...[-dx, -1, null, w, w + dx, null]);
y.push(...[index, index, index, index, index, index]);
for (let pos of positions1) {
// x relative to our slice, y in scene-coords
x.push(...[x0 - d, x0, null, x1, x1 + d, null]);
y.push(...[pos, pos, pos, pos, pos, pos]);
}
for (let index of indices2) {
x.push(...[index, index, index, index, index, index]);
y.push(...[-dy, -1, null, h, h + dy, null]);
for (let pos of positions2) {
// x in scene-coords, y relative to our slice
x.push(...[pos, pos, pos, pos, pos, pos]);
y.push(...[y0 - d, y0, null, y1, y1 + d, null]);
}
return {
type: 'scatter',
Expand All @@ -296,15 +327,15 @@ def _create_client_callbacks(self):
{
"scene": self.scene_id,
"context": ALL,
"name": "index",
"name": "position",
"axis": axis,
},
"data",
)
for axis in axii
],
[
State(self._subid("_slice-size"), "data"),
State(self._subid("info"), "data"),
State(self._subid("_indicators"), "data"),
],
)
11 changes: 11 additions & 0 deletions dash_3d_viewer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,14 @@ def get_thumbnail_size_from_shape(shape, base_size):
img_pil = PIL.Image.fromarray(img_array)
img_pil.thumbnail((base_size, base_size))
return img_pil.size


def shape3d_to_size2d(shape, axis):
"""Turn a 3d shape (z, y, x) into a local (x', y', z'),
where z' represents the dimension indicated by axis.
"""
shape = list(shape)
axis_value = shape.pop(axis)
size = list(reversed(shape))
size.append(axis_value)
return tuple(size)
27 changes: 18 additions & 9 deletions examples/slicer_with_1_plus_2_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
This demonstrates how multiple indicators can be shown per axis.
Sharing the same scene_id is enough for the slicers to show each-others
position. If the same volume object is given, it works by default,
position. If the same volume object would be given, it works by default,
because the default scene_id is a hash of the volume object. Specifying
a scene_id provides slice position indicators even when slicing through
different volumes.
Further, this example has one slider showing data with different spacing.
Note how the indicators represent the actual position in "scene coordinates".
"""

import dash
Expand All @@ -17,22 +21,29 @@

app = dash.Dash(__name__)

vol = imageio.volread("imageio:stent.npz")
slicer1 = DashVolumeSlicer(app, vol, axis=1, scene_id="myscene")
slicer2 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
slicer3 = DashVolumeSlicer(app, vol, axis=0, scene_id="myscene")
vol1 = imageio.volread("imageio:stent.npz")

vol2 = vol1[::3, ::2, :]
spacing = 3, 2, 1
origin = 110, 120, 140


slicer1 = DashVolumeSlicer(app, vol1, axis=1, origin=origin, scene_id="myscene")
slicer2 = DashVolumeSlicer(app, vol1, axis=0, origin=origin, scene_id="myscene")
slicer3 = DashVolumeSlicer(
app, vol2, axis=0, origin=origin, spacing=spacing, scene_id="myscene"
)

app.layout = html.Div(
style={
"display": "grid",
"grid-template-columns": "40% 40%",
"gridTemplateColumns": "40% 40%",
},
children=[
html.Div(
[
html.H1("Coronal"),
slicer1.graph,
html.Br(),
slicer1.slider,
*slicer1.stores,
]
Expand All @@ -41,7 +52,6 @@
[
html.H1("Transversal 1"),
slicer2.graph,
html.Br(),
slicer2.slider,
*slicer2.stores,
]
Expand All @@ -51,7 +61,6 @@
[
html.H1("Transversal 2"),
slicer3.graph,
html.Br(),
slicer3.slider,
*slicer3.stores,
]
Expand Down
2 changes: 1 addition & 1 deletion examples/slicer_with_2_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
app.layout = html.Div(
style={
"display": "grid",
"grid-template-columns": "40% 40%",
"gridTemplateColumns": "40% 40%",
},
children=[
html.Div(
Expand Down
2 changes: 1 addition & 1 deletion examples/slicer_with_3_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
app.layout = html.Div(
style={
"display": "grid",
"grid-template-columns": "40% 40%",
"gridTemplateColumns": "40% 40%",
},
children=[
html.Div(
Expand Down
14 changes: 14 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dash_3d_viewer.utils import shape3d_to_size2d

from pytest import raises


def test_shape3d_to_size2d():
# shape -> z, y, x
# size -> x, y, out-of-plane
assert shape3d_to_size2d((12, 13, 14), 0) == (14, 13, 12)
assert shape3d_to_size2d((12, 13, 14), 1) == (14, 12, 13)
assert shape3d_to_size2d((12, 13, 14), 2) == (13, 12, 14)

with raises(IndexError):
shape3d_to_size2d((12, 13, 14), 3)

0 comments on commit a8b7388

Please sign in to comment.