From a65b36824804137029916a8fa722ed909b7142fe Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:16:44 -0800 Subject: [PATCH] reimplement background estimation in torch --- waveorder/correction.py | 100 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 waveorder/correction.py diff --git a/waveorder/correction.py b/waveorder/correction.py new file mode 100644 index 0000000..7261e5a --- /dev/null +++ b/waveorder/correction.py @@ -0,0 +1,100 @@ +"""Background correction methods""" + +import torch +import torch.nn.functional as F +from torch import Tensor, Size + + +def _sample_block_medians(image: Tensor, block_size) -> Tensor: + """ + Sample densely tiled square blocks from a 2D image and return their medians. + Incomplete blocks (overhangs) will be ignored. + + Parameters + ---------- + image : Tensor + 2D image + block_size : int, optional + Width and height of the blocks + + Returns + ------- + Tensor + Median intensity values for each block, flattened + """ + blocks = F.unfold(image[None, None], block_size, stride=block_size)[0] + return blocks.median(0)[0] + + +def _grid_coordinates(image: Tensor, block_size: int) -> Tensor: + """Build image coordinates from the center points of square blocks""" + coords = torch.meshgrid( + [ + torch.arange( + 0 + block_size / 2, + boundary - block_size / 2 + 1, + block_size, + device=image.device, + ) + for boundary in image.shape + ] + ) + return torch.stack(coords, dim=-1).reshape(-1, 2) + + +def _fit_2d_polynomial_surface( + coords: Tensor, values: Tensor, order: int, surface_shape: Size +) -> Tensor: + """Fit a 2D polynomial to a set of coordinates and their values, + and return the surface evaluated at every point.""" + n_coeffs = int((order + 1) * (order + 2) / 2) + if n_coeffs >= len(values): + raise ValueError( + f"Cannot fit a {order} degree 2D polynomial " + f"with {len(values)} sampled values" + ) + orders = torch.arange(order + 1, device=coords.device) + order_pairs = torch.stack(torch.meshgrid(orders, orders), -1) + order_pairs = order_pairs[order_pairs.sum(-1) <= order].reshape(-1, 2) + terms = torch.stack( + [coords[:, 0] ** i * coords[:, 1] ** j for i, j in order_pairs], -1 + ) + # use "gels" driver for precision and GPU consistency + coeffs = torch.linalg.lstsq(terms, values, driver="gels").solution + dense_coords = torch.meshgrid( + [ + torch.arange(s, dtype=values.dtype, device=values.device) + for s in surface_shape + ] + ) + dense_terms = torch.stack( + [dense_coords[0] ** i * dense_coords[1] ** j for i, j in order_pairs], + -1, + ) + return torch.matmul(dense_terms, coeffs) + + +def estimate_background(image: Tensor, order: int = 2, block_size: int = 32): + """ + + + Combine sampling and polynomial surface fit for background estimation. + To background correct an image, divide it by the background. + + :param np.array im: 2D image + :param int order: Order of polynomial (default 2) + :param bool normalize: Normalize surface by dividing by its mean + for background correction (default True) + + :return np.array background: Background image + """ + if image.ndim != 2: + raise ValueError(f"Image must be 2D, got shape {image.shape}") + height, width = image.shape + if block_size > width: + raise ValueError("Block size larger than image height") + if block_size > height: + raise ValueError("Block size larger than image width") + medians = _sample_block_medians(image, block_size) + coords = _grid_coordinates(image, block_size) + return _fit_2d_polynomial_surface(coords, medians, order, image.shape)