Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Device agnostic compute for polarization #150

Merged
merged 11 commits into from
Jan 6, 2024
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch


def device_params():
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
if torch.backends.mps.is_available():
devices.append("mps")
return "device", devices


_DEVICE = device_params()
14 changes: 9 additions & 5 deletions tests/models/test_inplane_oriented_thick_pol3D.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch

from waveorder import stokes
from tests.conftest import _DEVICE
from waveorder.models import inplane_oriented_thick_pol3d


Expand All @@ -16,23 +16,27 @@ def test_calculate_transfer_function():
assert intensity_to_stokes_matrix.shape == (4, 5)


def test_apply_inverse_transfer_function():
input_shape = (5, 10, 5, 5)
czyx_data = torch.rand(input_shape)
@pytest.mark.parametrize(*_DEVICE)
@pytest.mark.parametrize("estimate_bg", [True, False])
def test_apply_inverse_transfer_function(device, estimate_bg):
input_shape = (5, 10, 100, 100)
czyx_data = torch.rand(input_shape, device=device)

intensity_to_stokes_matrix = (
inplane_oriented_thick_pol3d.calculate_transfer_function(
swing=0.1,
scheme="5-State",
)
).to(device)
)

results = inplane_oriented_thick_pol3d.apply_inverse_transfer_function(
czyx_data=czyx_data,
intensity_to_stokes_matrix=intensity_to_stokes_matrix,
remove_estimated_background=estimate_bg,
)

assert len(results) == 4

for result in results:
assert result.shape == input_shape[1:]
assert result.device.type == device
42 changes: 42 additions & 0 deletions tests/test_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
import torch

from tests.conftest import _DEVICE
from waveorder.correction import (
_fit_2d_polynomial_surface,
_grid_coordinates,
_sample_block_medians,
estimate_background,
)


def test_sample_block_medians():
image = torch.arange(4 * 5, dtype=torch.float).reshape(4, 5)
medians = _sample_block_medians(image, 2)
assert torch.allclose(
medians, torch.tensor([1, 3, 11, 13]).to(image.dtype)
)


def test_grid_coordinates():
image = torch.ones(15, 17)
coords = _grid_coordinates(image, 4)
assert coords.shape == (3 * 4, 2)


def test_fit_2d_polynomial_surface():
coords = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float)
values = torch.tensor([0, 1, 2, 3], dtype=torch.float)
surface = _fit_2d_polynomial_surface(coords, values, 1, (2, 2))
assert torch.allclose(surface, values.reshape(surface.shape), atol=1e-2)


@pytest.mark.parametrize("order", [1, 2, 3])
@pytest.mark.parametrize(*_DEVICE)
def test_estimate_background(order, device):
image = torch.rand(200, 200).to(device)
image[:100, :100] += 1
background = estimate_background(image, order=order, block_size=32)
assert 2.0 > background[50, 50] > 1.0
assert 1.5 > background[0, 100] > 0.5
assert 1.0 > background[150, 150] > 0.0
77 changes: 51 additions & 26 deletions tests/test_stokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from waveorder import stokes

from .conftest import _DEVICE


def test_S2I_matrix():
S2I5 = stokes.calculate_stokes_to_intensity_matrix(0.1)
Expand Down Expand Up @@ -35,18 +37,26 @@ def test_I2S_matrix():
tt.assert_close(I, torch.eye(I.shape[0]))


def test_s12_to_orientation():
for orientation in torch.linspace(0, np.pi, 25)[:-1]: # skip endpoint
@pytest.mark.parametrize(*_DEVICE)
def test_s12_to_orientation(device):
for orientation in torch.linspace(0, np.pi, 25, device=device)[
:-1
]: # skip endpoint
orientation1 = stokes._s12_to_orientation(
np.sin(2 * orientation), -np.cos(2 * orientation)
torch.sin(2 * orientation), -torch.cos(2 * orientation)
)
tt.assert_close(orientation, orientation1)


def test_stokes_recon():
# NOTE: skip retardance = 0 and depolarization = 0 because orientationentation is not defined
for retardance in torch.arange(1e-3, 1, 0.1): # fractions of a wave
for orientation in torch.arange(0, np.pi, np.pi / 10): # radians
@pytest.mark.parametrize(*_DEVICE)
def test_stokes_recon(device):
# NOTE: skip retardance = 0 and depolarization = 0 because orientation is not defined
for retardance in torch.arange(
1e-3, 1, 0.1, device=device
): # fractions of a wave
for orientation in torch.arange(
0, np.pi, np.pi / 10, device=device
): # radians
for transmittance in [0.1, 10]:
# Test attenuating retarder (ar) functions
ar = (retardance, orientation, transmittance)
Expand All @@ -56,7 +66,9 @@ def test_stokes_recon():
tt.assert_close(torch.tensor(ar[i]), ar1[i])

# Test attenuating depolarizing retarder (adr) functions
for depolarization in torch.arange(1e-3, 1, 0.1):
for depolarization in torch.arange(
1e-3, 1, 0.1, device=device
):
adr = (
retardance,
orientation,
Expand Down Expand Up @@ -109,22 +121,24 @@ def test_mueller_from_stokes():
tt.assert_close(torch.linalg.inv(M2), M2.T)


def test_mmul():
M = torch.ones((3, 2, 1))
x = torch.ones((2, 1))

@pytest.mark.parametrize(*_DEVICE)
def test_mmul(device):
M = torch.ones((3, 2, 1), device=device)
x = torch.ones((2, 1), device=device)
y = stokes.mmul(M, x) # should pass

assert y.shape == (3, 1)
assert y.device.type == device
with pytest.raises(ValueError):
M2 = torch.ones((3, 4, 1))
y2 = stokes.mmul(M2, x)


def test_copying():
a = torch.tensor([1, 1])
b = torch.tensor([1, 1])
c = torch.tensor([1, 1])
d = torch.tensor([1, 1])
@pytest.mark.parametrize(*_DEVICE)
def test_copying(device):
a = torch.tensor([1, 1], device=device)
b = torch.tensor([1, 1], device=device)
c = torch.tensor([1, 1], device=device)
d = torch.tensor([1, 1], device=device)
s0, s1, s2, s3 = stokes.stokes_after_adr(a, b, c, d)
s0[0] = 2 # modify the output
assert c[0] == 1 # check that the input hasn't changed
Expand All @@ -134,14 +148,19 @@ def test_copying():
assert a[0] == 1


def test_orientation_offset():
@pytest.mark.parametrize(*_DEVICE)
def test_orientation_offset(device):
ori = torch.tensor(
[0, torch.pi / 4, torch.pi / 2, torch.pi - 0.01, torch.pi]
[0, torch.pi / 4, torch.pi / 2, torch.pi - 0.01, torch.pi],
device=device,
)

ff = stokes.apply_orientation_offset(ori, rotate=False, flip=False)
assert torch.allclose(
ff, torch.tensor([0, torch.pi / 4, torch.pi / 2, torch.pi - 0.01, 0])
ff,
torch.tensor(
[0, torch.pi / 4, torch.pi / 2, torch.pi - 0.01, 0], device=device
),
)

tf = stokes.apply_orientation_offset(ori, rotate=True, flip=False)
Expand All @@ -154,26 +173,32 @@ def test_orientation_offset():
0,
(torch.pi / 2) - 0.01,
torch.pi / 2,
]
],
device=device,
),
)

ft = stokes.apply_orientation_offset(ori, rotate=False, flip=True)
assert torch.allclose(
ft,
torch.tensor([0, 3 * torch.pi / 4, torch.pi / 2, 0.01, 0]),
torch.tensor(
[0, 3 * torch.pi / 4, torch.pi / 2, 0.01, 0], device=device
),
)

tt = stokes.apply_orientation_offset(ori, rotate=True, flip=True)
rotated_fliped = stokes.apply_orientation_offset(
ori, rotate=True, flip=True
)
assert torch.allclose(
tt,
rotated_fliped,
torch.tensor(
[
torch.pi / 2,
torch.pi / 4,
0,
(torch.pi / 2) + 0.01,
torch.pi / 2,
]
],
device=device,
),
)
107 changes: 107 additions & 0 deletions waveorder/correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Background correction methods"""

import torch
import torch.nn.functional as F
from torch import Tensor, Size


def _sample_block_medians(image: Tensor, block_size) -> Tensor:
"""
Sample densely tiled square blocks from a 2D image and return their medians.
Incomplete blocks (overhangs) will be ignored.

Parameters
----------
image : Tensor
2D image
block_size : int, optional
Width and height of the blocks

Returns
-------
Tensor
Median intensity values for each block, flattened
"""
if not image.dtype.is_floating_point:
image.to(torch.float)
blocks = F.unfold(image[None, None], block_size, stride=block_size)[0]
return blocks.median(0)[0]


def _grid_coordinates(image: Tensor, block_size: int) -> Tensor:
"""Build image coordinates from the center points of square blocks"""
coords = torch.meshgrid(
[
torch.arange(
0 + block_size / 2,
boundary - block_size / 2 + 1,
block_size,
device=image.device,
)
for boundary in image.shape
]
)
return torch.stack(coords, dim=-1).reshape(-1, 2)


def _fit_2d_polynomial_surface(
coords: Tensor, values: Tensor, order: int, surface_shape: Size
) -> Tensor:
"""Fit a 2D polynomial to a set of coordinates and their values,
and return the surface evaluated at every point."""
n_coeffs = int((order + 1) * (order + 2) / 2)
if n_coeffs >= len(values):
raise ValueError(
f"Cannot fit a {order} degree 2D polynomial "
f"with {len(values)} sampled values"
)
orders = torch.arange(order + 1, device=coords.device)
order_pairs = torch.stack(torch.meshgrid(orders, orders), -1)
order_pairs = order_pairs[order_pairs.sum(-1) <= order].reshape(-1, 2)
terms = torch.stack(
[coords[:, 0] ** i * coords[:, 1] ** j for i, j in order_pairs], -1
)
# use "gels" driver for precision and GPU consistency
coeffs = torch.linalg.lstsq(terms, values, driver="gels").solution
dense_coords = torch.meshgrid(
[
torch.arange(s, dtype=values.dtype, device=values.device)
for s in surface_shape
]
)
dense_terms = torch.stack(
[dense_coords[0] ** i * dense_coords[1] ** j for i, j in order_pairs],
-1,
)
return torch.matmul(dense_terms, coeffs)


def estimate_background(image: Tensor, order: int = 2, block_size: int = 32):
"""
Combine sampling and polynomial surface fit for background estimation.
To background correct an image, divide it by the background.

Parameters
----------
image : Tensor
2D image
order : int, optional
Order of polynomial, by default 2
block_size : int, optional
Width and height of the blocks, by default 32

Returns
-------
Tensor
Background image
"""
if image.ndim != 2:
raise ValueError(f"Image must be 2D, got shape {image.shape}")
height, width = image.shape
if block_size > width:
raise ValueError("Block size larger than image height")
if block_size > height:
raise ValueError("Block size larger than image width")
medians = _sample_block_medians(image, block_size)
coords = _grid_coordinates(image, block_size)
return _fit_2d_polynomial_surface(coords, medians, order, image.shape)
8 changes: 3 additions & 5 deletions waveorder/models/inplane_oriented_thick_pol3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor

from waveorder import background_estimator, stokes, util
from waveorder import correction, stokes, util


def generate_test_phantom(yx_shape):
Expand Down Expand Up @@ -125,7 +125,6 @@ def apply_inverse_transfer_function(

# Apply an "Estimated" background correction
if remove_estimated_background:
estimator = background_estimator.BackgroundEstimator2D()
for stokes_index in range(background_corrected_stokes.shape[0]):
# Project to 2D
z_projection = torch.mean(
Expand All @@ -134,9 +133,8 @@ def apply_inverse_transfer_function(
# Estimate the background and subtract
background_corrected_stokes[
stokes_index
] -= estimator.get_background(
z_projection,
normalize=False,
] -= correction.estimate_background(
z_projection, order=2, block_size=32
)

# Project to 2D (typically for SNR reasons)
Expand Down
Loading
Loading