From de8fc4a43602912b9ffb7679c09ef72d0336a0da Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Mon, 11 Sep 2023 12:44:35 +0200 Subject: [PATCH 01/15] Fixed output shape --- .../scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index dcace425b..1334afe3a 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -94,7 +94,7 @@ def run(images_md_path: str, output_images_path = str(pathlib.Path(output_md_path).with_suffix('.mrc')) output_mrc = mrcfile.new_mmap( output_images_path, - shape=(len(images_md), 1) + image_size, + shape=(len(images_md), ) + image_size, mrc_mode=2, overwrite=True ) @@ -209,4 +209,4 @@ def run(images_md_path: str, q0=args.q0, batch_size = args.batch, device_names = args.device - ) \ No newline at end of file + ) From 62b9ddece42106d5f0de80d9e85e69ef2c3a9d63 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Mon, 11 Sep 2023 13:04:53 +0200 Subject: [PATCH 02/15] Bugfix --- .../scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index 1334afe3a..a04e6fb1a 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -167,7 +167,7 @@ def run(images_md_path: str, torch.fft.irfft2(batch_images_fourier, out=batch_images) # Store the result - output_images[batch_slice,0] = batch_images.to('cpu', non_blocking=True) + output_images[batch_slice] = batch_images.to('cpu', non_blocking=True) # Prepare for the next batch utils.progress_bar(end, len(images_md)) From 0dd47b08b7644d4bf73112fb936598d2a67c59ea Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Thu, 21 Sep 2023 12:49:26 +0200 Subject: [PATCH 03/15] Reseting index --- .../scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index a04e6fb1a..1cbf5a8a0 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -78,6 +78,7 @@ def run(images_md_path: str, # Read input files images_md = md.sort_by_image_filename(md.read(images_md_path)) + images_md.reset_index(drop=True, inplace=True) images_paths = list(map(image.parse_path, images_md[md.IMAGE])) images_dataset = image.torch_utils.Dataset(images_paths) images_loader = torch.utils.data.DataLoader( From d7a8b9b6fee5458f2cb8e9d888f0bc183a677f08 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Thu, 12 Oct 2023 17:09:25 +0000 Subject: [PATCH 04/15] Fixed synchronization issue --- .../scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index 1cbf5a8a0..f7ad57712 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -176,6 +176,11 @@ def run(images_md_path: str, assert(end == len(images_md)) + # Wait for transfers to finish + if transform_device.type == 'cuda': + torch.cuda.synchronize(transform_device) + + # Update metadata images_md[md.IMAGE] = (images_md.index + 1).map(('{:06d}@' + output_images_path).format) md.write(images_md, output_md_path) From 98a45fad1671e7eca172b7ef3528552fbfd5df6a Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Mon, 26 Feb 2024 14:27:48 +0000 Subject: [PATCH 05/15] Fixing merge error --- .../py_xmipp/swiftalign/fourier/__init__.py | 27 +------------------ 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py index 3f8426cc9..da3698668 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py @@ -20,29 +20,4 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ -from typing import Optional, Sequence -import torch - -def wiener_2d( direct_filter: torch.Tensor, - inverse_ssnr: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None) -> torch.Tensor: - - # Compute the filter power (|H|²) at the output - if torch.is_complex(direct_filter): - out = torch.abs(direct_filter, out=out) - out.square_() - else: - out = torch.square(direct_filter, out=out) - - # Compute the default value for inverse SSNR if not - # provided - if inverse_ssnr is None: - inverse_ssnr = torch.mean(out, dim=(-2, -1)) - inverse_ssnr *= 0.1 - inverse_ssnr = inverse_ssnr[...,None,None] - - # H* / (|H|² + N/S) - out.add_(inverse_ssnr) - torch.div(torch.conj(direct_filter), out, out=out) - - return out +from .rfftnfreq import rfftnfreq \ No newline at end of file From 6908f9dd6485be31b6f7b1ac09e5342bd031a425 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Mon, 26 Feb 2024 14:40:19 +0000 Subject: [PATCH 06/15] Using a CTF descriptor class --- .../swiftalign_wiener_2d.py | 17 ++++-- .../py_xmipp/swiftalign/ctf/__init__.py | 2 +- .../swiftalign/ctf/compute_ctf_image_2d.py | 58 ++++++++++--------- 3 files changed, 43 insertions(+), 34 deletions(-) diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index f7ad57712..54526215d 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -138,18 +138,23 @@ def run(images_md_path: str, batch_images_fourier.resize_(0) # Force explicit reuse batch_images_fourier = torch.fft.rfft2(batch_images, out=batch_images_fourier) + # Elaborate the CTF descriptor + ctf_desc = ctf.Ctf2dDesc( + wavelength=wavelength, + spherical_aberration=spherical_aberration, + defocus_average=defocus[:,0], + defocus_difference=defocus[:,1], + astigmatism_angle=defocus[:,2], + q0=q0 + ) + # Compute the CTF image if ctf_images is not None: ctf_images.resize_(0) # Force explicit reuse ctf_images = ctf.compute_ctf_image_2d( frequency_magnitude2_grid=polar_frequency_grid[0], frequency_angle_grid=polar_frequency_grid[1], - defocus_average=defocus[:,0], - defocus_difference=defocus[:,1], - astigmatism_angle=defocus[:,2], - wavelength=wavelength, - spherical_aberration=spherical_aberration, - q0=q0, + ctf_desc=ctf_desc, out=ctf_images ) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/__init__.py index cec32742c..8c7dfd60a 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/__init__.py @@ -20,5 +20,5 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ -from .compute_ctf_image_2d import compute_ctf_image_2d +from .compute_ctf_image_2d import Ctf2dDesc, compute_ctf_image_2d from .wiener import wiener_2d \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py index 336abf87d..4aa7b0aa0 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py @@ -20,10 +20,23 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ -from typing import Optional +from typing import Optional, NamedTuple import torch import math +class Ctf2dDesc(NamedTuple): + wavelength: float + spherical_aberration: float + defocus_average: torch.Tensor + defocus_difference: torch.Tensor + astigmatism_angle: torch.Tensor + q0: Optional[float] = None + chromatic_aberration: Optional[torch.Tensor] = None + energy_spread_coefficient: Optional[float] = None + lens_inestability_coefficient: Optional[float] = None + phase_shift: Optional[float] = None + + def _compute_defocus_grid_2d(frequency_angle_grid: torch.Tensor, defocus_average: torch.Tensor, defocus_difference: torch.Tensor, @@ -59,54 +72,45 @@ def _compute_beam_energy_spread(frequency_magnitude2_grid: torch.Tensor, def compute_ctf_image_2d(frequency_magnitude2_grid: torch.Tensor, frequency_angle_grid: torch.Tensor, - defocus_average: torch.Tensor, - defocus_difference: torch.Tensor, - astigmatism_angle: torch.Tensor, - wavelength: float, - spherical_aberration: float, - q0: Optional[float] = None, - chromatic_aberration: Optional[torch.Tensor] = None, - energy_spread_coefficient: Optional[float] = None, - lens_inestability_coefficient: Optional[float] = None, - phase_shift: Optional[float] = None, + ctf_desc: Ctf2dDesc, out: Optional[torch.Tensor] = None ) -> torch.Tensor: - k = 0.5 * spherical_aberration * wavelength * wavelength + k = 0.5 * ctf_desc.spherical_aberration * ctf_desc.wavelength * ctf_desc.wavelength out = _compute_defocus_grid_2d( frequency_angle_grid=frequency_angle_grid, - defocus_average=defocus_average, - defocus_difference=defocus_difference, - astigmatism_angle=astigmatism_angle, + defocus_average=ctf_desc.defocus_average, + defocus_difference=ctf_desc.defocus_difference, + astigmatism_angle=ctf_desc.astigmatism_angle, out=out ) # Compute the phase out -= k*frequency_magnitude2_grid - out *= (torch.pi * wavelength) * frequency_magnitude2_grid + out *= (torch.pi * ctf_desc.wavelength) * frequency_magnitude2_grid # Apply the phase shift if provided - if phase_shift is not None: - out += phase_shift + if ctf_desc.phase_shift is not None: + out += ctf_desc.phase_shift # Compute the sin, also considering the inelastic # difraction factor if provided - if q0 is not None: - out = out.sin() + q0*out.cos() + if ctf_desc.q0 is not None: + out = out.sin() + ctf_desc.q0*out.cos() else: out.sin_() # Apply energy spread envelope - if (chromatic_aberration is not None) and \ - (energy_spread_coefficient is not None) and \ - (lens_inestability_coefficient is not None): + if (ctf_desc.chromatic_aberration is not None) and \ + (ctf_desc.energy_spread_coefficient is not None) and \ + (ctf_desc.lens_inestability_coefficient is not None): beam_energy_spread = _compute_beam_energy_spread( frequency_magnitude2_grid=frequency_magnitude2_grid, - chromatic_aberration=chromatic_aberration, - wavelength=wavelength, - energy_spread_coefficient=energy_spread_coefficient, - lens_inestability_coefficient=lens_inestability_coefficient + chromatic_aberration=ctf_desc.chromatic_aberration, + wavelength=ctf_desc.wavelength, + energy_spread_coefficient=ctf_desc.energy_spread_coefficient, + lens_inestability_coefficient=ctf_desc.lens_inestability_coefficient ) out *= beam_energy_spread From 41b9535fd2c5ba676d38aa9d399310919c6e4bc0 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Mon, 26 Feb 2024 17:21:52 +0000 Subject: [PATCH 07/15] Remove unused envelope --- .../swiftalign/ctf/compute_ctf_image_2d.py | 39 ------------------- 1 file changed, 39 deletions(-) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py index 4aa7b0aa0..e523cfd39 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py @@ -31,12 +31,8 @@ class Ctf2dDesc(NamedTuple): defocus_difference: torch.Tensor astigmatism_angle: torch.Tensor q0: Optional[float] = None - chromatic_aberration: Optional[torch.Tensor] = None - energy_spread_coefficient: Optional[float] = None - lens_inestability_coefficient: Optional[float] = None phase_shift: Optional[float] = None - def _compute_defocus_grid_2d(frequency_angle_grid: torch.Tensor, defocus_average: torch.Tensor, defocus_difference: torch.Tensor, @@ -51,25 +47,6 @@ def _compute_defocus_grid_2d(frequency_angle_grid: torch.Tensor, return out -def _compute_beam_energy_spread(frequency_magnitude2_grid: torch.Tensor, - chromatic_aberration: torch.Tensor, - wavelength: float, - energy_spread_coefficient: float, - lens_inestability_coefficient: float, - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - # http://i2pc.es/coss/Articulos/Sorzano2007a.pdf - # Equation 10 - k = torch.pi / 4 * wavelength * (energy_spread_coefficient + 2*lens_inestability_coefficient) - x = chromatic_aberration * k - x.square_() - x *= -1.0 / math.log(2) - - out = torch.mul(x[...,None,None], frequency_magnitude2_grid.square(), out=out) - out.exp_() - - return out - def compute_ctf_image_2d(frequency_magnitude2_grid: torch.Tensor, frequency_angle_grid: torch.Tensor, ctf_desc: Ctf2dDesc, @@ -100,21 +77,5 @@ def compute_ctf_image_2d(frequency_magnitude2_grid: torch.Tensor, else: out.sin_() - # Apply energy spread envelope - if (ctf_desc.chromatic_aberration is not None) and \ - (ctf_desc.energy_spread_coefficient is not None) and \ - (ctf_desc.lens_inestability_coefficient is not None): - - beam_energy_spread = _compute_beam_energy_spread( - frequency_magnitude2_grid=frequency_magnitude2_grid, - chromatic_aberration=ctf_desc.chromatic_aberration, - wavelength=ctf_desc.wavelength, - energy_spread_coefficient=ctf_desc.energy_spread_coefficient, - lens_inestability_coefficient=ctf_desc.lens_inestability_coefficient - ) - out *= beam_energy_spread - - - return out \ No newline at end of file From f26f5c6cd165225c16e09df4771a1838d1e1c58f Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Wed, 28 Feb 2024 08:54:00 +0000 Subject: [PATCH 08/15] Added missing complimentary Q0 --- .../libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py index e523cfd39..b71d3ca4b 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py @@ -73,7 +73,9 @@ def compute_ctf_image_2d(frequency_magnitude2_grid: torch.Tensor, # Compute the sin, also considering the inelastic # difraction factor if provided if ctf_desc.q0 is not None: - out = out.sin() + ctf_desc.q0*out.cos() + cos_q0 = ctf_desc.q0 + sin_q0 = math.sqrt(1.0 - cos_q0**2) + out = sin_q0*out.sin() + cos_q0*out.cos() else: out.sin_() From 1ae2b8fc380cbd93ba10764a6b2178d1dd5ac0e5 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Wed, 28 Feb 2024 10:59:53 +0000 Subject: [PATCH 09/15] Added padding option --- .../swiftalign_wiener_2d.py | 32 ++++++++--- .../py_xmipp/swiftalign/fourier/__init__.py | 3 +- .../py_xmipp/swiftalign/fourier/zero_pad.py | 55 +++++++++++++++++++ 3 files changed, 82 insertions(+), 8 deletions(-) create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/fourier/zero_pad.py diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index 54526215d..158be8438 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -65,6 +65,7 @@ def run(images_md_path: str, voltage: float, phase_flipped: bool, q0: float, + padding: int, batch_size: int, device_names: list ): @@ -90,6 +91,7 @@ def run(images_md_path: str, num_workers=1 ) image_size = md.get_image2d_size(images_md) + padded_size = (image_size[0]*padding, image_size[1]*padding) # Create a MMAPed output file output_images_path = str(pathlib.Path(output_md_path).with_suffix('.mrc')) @@ -109,7 +111,7 @@ def run(images_md_path: str, # Compute frequency grid cartesian_frequency_grid = fourier.rfftnfreq( - image_size, + padded_size, d=pixel_size, device=transform_device ) @@ -120,6 +122,7 @@ def run(images_md_path: str, batch_images_fourier = None ctf_images = None wiener_filters = None + padded_images = None for batch_images in images_loader: batch_images = batch_images[0] # Due to the BatchSampler end = start + len(batch_images) @@ -127,16 +130,25 @@ def run(images_md_path: str, batch_slice = slice(start, end) batch_images_md = images_md.iloc[batch_slice] + # Obtain defocus defocus = torch.from_numpy(batch_images_md[[md.CTF_DEFOCUS_U, md.CTF_DEFOCUS_V, md.CTF_DEFOCUS_ANGLE]].to_numpy()) _compute_differential_defocus_inplace(defocus[:,:2]) defocus[:,2].deg2rad_() defocus = defocus.to(transform_device, non_blocking=True) + # Zero pad images if necessary + padded_images = fourier.zero_pad( + batch_images, + dim=(-2, -1), + factor=padding, + out=padded_images + ) + # Perform the FFT of the images if batch_images_fourier is not None: batch_images_fourier.resize_(0) # Force explicit reuse - batch_images_fourier = torch.fft.rfft2(batch_images, out=batch_images_fourier) + batch_images_fourier = torch.fft.rfft2(padded_images, out=batch_images_fourier) # Elaborate the CTF descriptor ctf_desc = ctf.Ctf2dDesc( @@ -170,11 +182,15 @@ def run(images_md_path: str, batch_images_fourier *= wiener_filters # Perform the inverse FFT computaion - torch.fft.irfft2(batch_images_fourier, out=batch_images) - - # Store the result - output_images[batch_slice] = batch_images.to('cpu', non_blocking=True) + torch.fft.irfft2(batch_images_fourier, out=padded_images) + # Undo padding and store + if padded_images is batch_images: + output_images[batch_slice] = batch_images.to('cpu', non_blocking=True) + else: + read_slice = tuple(map(slice, batch_images.shape)) + output_images[batch_slice] = padded_images[read_slice].to('cpu', non_blocking=True) + # Prepare for the next batch utils.progress_bar(end, len(images_md)) start = end @@ -203,6 +219,7 @@ def run(images_md_path: str, parser.add_argument('--voltage', type=float, required=True) parser.add_argument('--q0', type=float, default=0.1) parser.add_argument('--phase_flipped', action='store_true') + parser.add_argument('--padding', type=int, default=1) parser.add_argument('--batch', type=int, default=1024) parser.add_argument('--device', nargs='*') @@ -216,8 +233,9 @@ def run(images_md_path: str, pixel_size=args.pixel_size, spherical_aberration=args.spherical_aberration, voltage=args.voltage, - phase_flipped=args.phase_flipped, q0=args.q0, + phase_flipped=args.phase_flipped, + padding=args.padding, batch_size = args.batch, device_names = args.device ) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py index da3698668..e0fdbd48d 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py @@ -20,4 +20,5 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ -from .rfftnfreq import rfftnfreq \ No newline at end of file +from .rfftnfreq import rfftnfreq +from .zero_pad import zero_pad \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/zero_pad.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/zero_pad.py new file mode 100644 index 000000000..cffe8c4c1 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/zero_pad.py @@ -0,0 +1,55 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Sequence, Optional +import torch + +def _compute_padded_shape(shape: torch.Size, dim: Sequence[int], factor: int) -> torch.Size: + shape_as_list = list(shape) # To mutate + for d in dim: + shape_as_list[d] *= factor + return torch.Size(shape_as_list) + +def zero_pad(x: torch.Tensor, + dim: Sequence[int], + factor: int, + copy: bool = False, + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + if factor > 1 or copy: + padded_size = _compute_padded_shape(x.shape, dim=dim, factor=factor) + out = torch.zeros( + size=padded_size, + dtype=x.dtype, + device=x.device, + out=out + ) + + # Write + write_slice = tuple(map(slice, x.shape)) + out[write_slice] = x + + else: + out = x + + return out + + From 6b9680aed4e8443a4931dfbb038ce7d4077f450c Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Wed, 28 Feb 2024 11:56:59 +0000 Subject: [PATCH 10/15] Removing extra line --- .../scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index 158be8438..e25b7f38e 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -201,7 +201,6 @@ def run(images_md_path: str, if transform_device.type == 'cuda': torch.cuda.synchronize(transform_device) - # Update metadata images_md[md.IMAGE] = (images_md.index + 1).map(('{:06d}@' + output_images_path).format) md.write(images_md, output_md_path) From 4a25152758711ad850b2f8f293f38ea26660b449 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Thu, 7 Mar 2024 10:31:30 +0000 Subject: [PATCH 11/15] Trying to fix GPU sync issues --- .../scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index e25b7f38e..f1cc4f18f 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -186,10 +186,10 @@ def run(images_md_path: str, # Undo padding and store if padded_images is batch_images: - output_images[batch_slice] = batch_images.to('cpu', non_blocking=True) + output_images[batch_slice] = batch_images.to('cpu') else: read_slice = tuple(map(slice, batch_images.shape)) - output_images[batch_slice] = padded_images[read_slice].to('cpu', non_blocking=True) + output_images[batch_slice] = padded_images[read_slice].to('cpu') # Prepare for the next batch utils.progress_bar(end, len(images_md)) @@ -197,10 +197,6 @@ def run(images_md_path: str, assert(end == len(images_md)) - # Wait for transfers to finish - if transform_device.type == 'cuda': - torch.cuda.synchronize(transform_device) - # Update metadata images_md[md.IMAGE] = (images_md.index + 1).map(('{:06d}@' + output_images_path).format) md.write(images_md, output_md_path) From 72d9d1dcb23d7b452b7ee2825e805ff066038882 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Tue, 16 Apr 2024 08:21:02 +0000 Subject: [PATCH 12/15] Removed non_blocking operations --- .../scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py index f1cc4f18f..5b5bf414e 100755 --- a/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py +++ b/src/xmipp/applications/scripts/swiftalign_wiener_2d/swiftalign_wiener_2d.py @@ -126,7 +126,7 @@ def run(images_md_path: str, for batch_images in images_loader: batch_images = batch_images[0] # Due to the BatchSampler end = start + len(batch_images) - batch_images: torch.Tensor = batch_images.to(transform_device, non_blocking=True) + batch_images: torch.Tensor = batch_images.to(transform_device) batch_slice = slice(start, end) batch_images_md = images_md.iloc[batch_slice] @@ -135,7 +135,7 @@ def run(images_md_path: str, defocus = torch.from_numpy(batch_images_md[[md.CTF_DEFOCUS_U, md.CTF_DEFOCUS_V, md.CTF_DEFOCUS_ANGLE]].to_numpy()) _compute_differential_defocus_inplace(defocus[:,:2]) defocus[:,2].deg2rad_() - defocus = defocus.to(transform_device, non_blocking=True) + defocus = defocus.to(transform_device) # Zero pad images if necessary padded_images = fourier.zero_pad( From 55022da6ffb694bf7fdb58b62b8340e77a0b2af3 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Tue, 14 May 2024 15:26:30 +0000 Subject: [PATCH 13/15] Re-added lost conda env --- .../bindings/python/envs_DLTK/xmipp_graph.yml | 13 ------------- .../python/envs_DLTK/xmipp_swiftalign-gpu.yml | 19 +++++++++++++++++++ .../python/envs_DLTK/xmipp_swiftalign.yml | 17 +++++++++++++++++ src/xmipp/bindings/python/xmipp_conda_envs.py | 4 ++-- 4 files changed, 38 insertions(+), 15 deletions(-) delete mode 100644 src/xmipp/bindings/python/envs_DLTK/xmipp_graph.yml create mode 100644 src/xmipp/bindings/python/envs_DLTK/xmipp_swiftalign-gpu.yml create mode 100644 src/xmipp/bindings/python/envs_DLTK/xmipp_swiftalign.yml diff --git a/src/xmipp/bindings/python/envs_DLTK/xmipp_graph.yml b/src/xmipp/bindings/python/envs_DLTK/xmipp_graph.yml deleted file mode 100644 index d292d5738..000000000 --- a/src/xmipp/bindings/python/envs_DLTK/xmipp_graph.yml +++ /dev/null @@ -1,13 +0,0 @@ -#Protocols ussing the enviroment: -#xmipp_graph_max_cut - -name: xmipp_graph -channels: - - anaconda - - conda-forge - - defaults -dependencies: - - python=3.8 - - scipy=1.10 - - networkx=2.8 - - cvxpy=1.3 \ No newline at end of file diff --git a/src/xmipp/bindings/python/envs_DLTK/xmipp_swiftalign-gpu.yml b/src/xmipp/bindings/python/envs_DLTK/xmipp_swiftalign-gpu.yml new file mode 100644 index 000000000..348ebb35d --- /dev/null +++ b/src/xmipp/bindings/python/envs_DLTK/xmipp_swiftalign-gpu.yml @@ -0,0 +1,19 @@ +#Protocols ussing the enviroment: protocol_deepHand + +name: xmipp_swiftalign +channels: + - anaconda + - defaults + - conda-forge + - pytorch + - nvidia +dependencies: + - python=3.10 + - numpy=1.23 + - mrcfile=1.4.3 + - kornia=0.6.8 #The last that is provided as *.bz2 extension + - starfile=0.4.11 #The last that that is provided as *.bz2 extension + - pytorch==1.13.1 + - pytorch-cuda=11.7 + - torchvision=0.14 + - faiss-gpu=1.8.0 diff --git a/src/xmipp/bindings/python/envs_DLTK/xmipp_swiftalign.yml b/src/xmipp/bindings/python/envs_DLTK/xmipp_swiftalign.yml new file mode 100644 index 000000000..f9e1afc7b --- /dev/null +++ b/src/xmipp/bindings/python/envs_DLTK/xmipp_swiftalign.yml @@ -0,0 +1,17 @@ +#Protocols ussing the enviroment: protocol_deepHand + +name: xmipp_pyTorch +channels: + - anaconda + - defaults + - conda-forge + - pytorch +dependencies: + - python=3.10 + - numpy=1.23 + - mrcfile=1.4.3 + - kornia=0.6.8 #The last that that is provided as *.bz2 extension + - starfile=0.4.11 #The last that that is provided as *.bz2 extension + - pytorch==1.13.1 + - torchvision=0.14 + - faiss-cpu=1.8.0 diff --git a/src/xmipp/bindings/python/xmipp_conda_envs.py b/src/xmipp/bindings/python/xmipp_conda_envs.py index 71f4a950f..2ab86fb65 100644 --- a/src/xmipp/bindings/python/xmipp_conda_envs.py +++ b/src/xmipp/bindings/python/xmipp_conda_envs.py @@ -28,8 +28,8 @@ "xmippEnviron": True }, - "xmipp_graph": { - "requirements": os.path.join(_REQUIREMENT_PATH, 'xmipp_graph.yml'), + "xmipp_swiftalign": { + "requirements": os.path.join(_REQUIREMENT_PATH, 'xmipp_swiftalign.yml'), "xmippEnviron": True }, } From b6accaafda14f41f7bffe5afbe3488f0ba776fc3 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Tue, 14 May 2024 16:13:23 +0000 Subject: [PATCH 14/15] Revert "Removed unused swiftalign code" This reverts commit a74caa97652674955cbd574a1ac71f865df4a059. --- .../libraries/py_xmipp/swiftalign/__init__.py | 8 +- .../FourierInPlaneTransformAugmenter.py | 81 +++++++++ .../FourierInPlaneTransformCorrector.py | 111 ++++++++++++ .../FourierInPlaneTransformGenerator.py | 141 +++++++++++++++ .../alignment/InPlaneTransformBatch.py | 30 ++++ .../py_xmipp/swiftalign/alignment/__init__.py | 8 + .../py_xmipp/swiftalign/alignment/align.py | 52 ++++++ .../alignment/generate_alignment_metadata.py | 166 ++++++++++++++++++ .../py_xmipp/swiftalign/alignment/populate.py | 71 ++++++++ .../py_xmipp/swiftalign/alignment/train.py | 45 +++++ .../py_xmipp/swiftalign/dct/__init__.py | 25 +++ .../py_xmipp/swiftalign/dct/basis.py | 60 +++++++ .../libraries/py_xmipp/swiftalign/dct/dct.py | 48 +++++ .../py_xmipp/swiftalign/dct/project.py | 70 ++++++++ .../py_xmipp/swiftalign/fourier/__init__.py | 4 +- .../fourier/remove_symmetic_half.py | 36 ++++ .../py_xmipp/swiftalign/fourier/rfftnfreq.py | 3 +- .../swiftalign/fourier/time_shift_filter.py | 45 +++++ .../py_xmipp/swiftalign/math/__init__.py | 26 +++ .../swiftalign/math/complex_normalize.py | 35 ++++ .../swiftalign/math/flat_view_as_real.py | 28 +++ .../py_xmipp/swiftalign/math/l2_normalize.py | 36 ++++ .../swiftalign/math/mu_sigma_normalize.py | 38 ++++ .../operators/DctLowPassFlattener.py | 58 ++++++ .../swiftalign/operators/DctTransformer2D.py | 47 +++++ .../operators/FourierLowPassFlattener.py | 62 +++++++ .../operators/FourierShiftFilter.py | 64 +++++++ .../operators/FourierTransformer2D.py | 40 +++++ .../swiftalign/operators/ImageRotator.py | 51 ++++++ .../swiftalign/operators/ImageShifter.py | 53 ++++++ .../operators/ImageSpectraFlattener.py | 74 ++++++++ .../swiftalign/operators/SpectraFlattener.py | 33 ++++ .../swiftalign/operators/Transformer2D.py | 33 ++++ .../py_xmipp/swiftalign/operators/__init__.py | 11 +- .../py_xmipp/swiftalign/search/Database.py | 81 +++++++++ .../py_xmipp/swiftalign/search/Faiss.py | 145 +++++++++++++++ .../py_xmipp/swiftalign/search/MedianHash.py | 139 +++++++++++++++ .../py_xmipp/swiftalign/search/__init__.py | 25 +++ 38 files changed, 2076 insertions(+), 7 deletions(-) create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformAugmenter.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformCorrector.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformGenerator.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/alignment/InPlaneTransformBatch.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/alignment/align.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/alignment/generate_alignment_metadata.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/alignment/populate.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/alignment/train.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/dct/__init__.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/dct/basis.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/dct/dct.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/dct/project.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/fourier/remove_symmetic_half.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/fourier/time_shift_filter.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/math/__init__.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/math/complex_normalize.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/math/flat_view_as_real.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/math/l2_normalize.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/math/mu_sigma_normalize.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/DctLowPassFlattener.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/DctTransformer2D.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierLowPassFlattener.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierShiftFilter.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierTransformer2D.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageRotator.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageShifter.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageSpectraFlattener.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/SpectraFlattener.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/operators/Transformer2D.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/search/Database.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/search/Faiss.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/search/MedianHash.py create mode 100644 src/xmipp/libraries/py_xmipp/swiftalign/search/__init__.py diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/__init__.py index 9c776e1f4..aaff76b53 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/__init__.py @@ -20,12 +20,14 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ -from . import ctf -from . import fourier from . import alignment from . import classification +from . import ctf +from . import dct +from . import fourier from . import image from . import metadata from . import operators +from . import search from . import transform -from . import utils \ No newline at end of file +from . import utils diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformAugmenter.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformAugmenter.py new file mode 100644 index 000000000..dba34f6a7 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformAugmenter.py @@ -0,0 +1,81 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Iterable, Optional +import torch +import torchvision.transforms as T + +import random + +from .. import operators +from .. import math + +class FourierInPlaneTransformAugmenter: + def __init__(self, + max_psi: float, + max_shift: float, + flattener: operators.SpectraFlattener, + ctfs: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, + norm: Optional[str] = None, + interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR ) -> None: + + # Random affine transformer + self.random_affine = T.RandomAffine( + degrees=max_psi, + translate=(max_shift, )*2, + interpolation=interpolation + ) + + # Operations + self.fourier = operators.FourierTransformer2D() + self.flattener = flattener + self.ctfs = ctfs + self.weights = weights + self.norm = norm + + def __call__(self, + images: Iterable[torch.Tensor], + times: int = 1) -> torch.Tensor: + + if self.norm: + raise NotImplementedError('Normalization is not implemented') + + images_affine = None + images_fourier_transform = None + images_band = None + + for batch in images: + for _ in range(times): + images_affine = self.random_affine(batch) + images_fourier_transform = self.fourier(images_affine, out=images_fourier_transform) + images_band = self.flattener(images_fourier_transform, out=images_band) + + if self.weights is not None: + images_band *= self.weights + + if self.ctfs is not None: + # Select a random CTF and apply it + ctf = self.ctfs[random.randrange(len(self.ctfs))] + images_band *= ctf + + yield math.flat_view_as_real(images_band) \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformCorrector.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformCorrector.py new file mode 100644 index 000000000..1b36d21b6 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformCorrector.py @@ -0,0 +1,111 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Iterable, Optional, Tuple +import pandas as pd +import torch + +from .. import operators +from .. import transform +from .. import math +from .. import metadata as md + +class FourierInPlaneTransformCorrector: + def __init__(self, + flattener: operators.SpectraFlattener, + weights: Optional[torch.Tensor] = None, + norm: Optional[str] = None, + interpolation: str = 'bilinear', + device: Optional[torch.device] = None ) -> None: + + # Operations + self.fourier = operators.FourierTransformer2D() + self.flattener = flattener + self.weights = weights + self.interpolation = interpolation + self.norm = norm + + # Cached variables + self.rotated_images = None + self.rotated_fourier_transform = None + self.rotated_band = None + self.shifted_rotated_band = None + + # Device + self.device = device + + def __call__(self, + images: Iterable[Tuple[torch.Tensor, pd.DataFrame]]) -> torch.Tensor: + + if self.norm: + raise NotImplementedError('Normalization is not implemented') + + angles = None + transform_matrix = None + fourier_transforms = None + bands = None + + TRANSFORM_LABELS = [md.ANGLE_PSI, md.SHIFT_X, md.SHIFT_Y] + + for batch_images, batch_md in images: + if len(batch_images) != len(batch_md): + raise RuntimeError('Metadata and image batch sizes do not match') + + if self.device is not None: + batch_images = batch_images.to(self.device, non_blocking=True) + + if all(map(batch_md.columns.__contains__, TRANSFORM_LABELS)): + transformations = torch.as_tensor(batch_md[TRANSFORM_LABELS].to_numpy(), dtype=torch.float32) + angles = torch.deg2rad(transformations[:,0], out=angles) + shifts = transformations[:,1:] + centre = torch.tensor(batch_images.shape[-2:]) / 2 + + transform_matrix = transform.affine_matrix_2d( + angles=angles, + shifts=shifts, + centre=centre, + shift_first=True, + out=transform_matrix + ) + + batch_images = transform.affine_2d( + images=batch_images, + matrices=transform_matrix.to(batch_images, non_blocking=True), + interpolation=self.interpolation, + padding='zeros' + ) + + fourier_transforms = self.fourier( + batch_images, + out=fourier_transforms + ) + + bands = self.flattener( + fourier_transforms, + out=bands + ) + + if self.weights is not None: + bands *= self.weights + + yield math.flat_view_as_real(bands) + \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformGenerator.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformGenerator.py new file mode 100644 index 000000000..19d4605d0 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformGenerator.py @@ -0,0 +1,141 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Sequence, Iterable, Optional +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +import itertools + +from .. import operators +from .. import math +from .. import fourier +from .InPlaneTransformBatch import InPlaneTransformBatch + +def _compute_shift_filters( shifts: torch.Tensor, + flattener: operators.SpectraFlattener, + dim: Sequence[int], + device: Optional[torch.device] = None ) -> torch.Tensor: + d = 1.0/(2*torch.pi) + frequency_grid = fourier.rfftnfreq(dim, d=d, device=device) + frequency_coefficients = flattener(frequency_grid) + return fourier.time_shift_filter(shifts.to(frequency_coefficients.device), frequency_coefficients) + +class FourierInPlaneTransformGenerator: + def __init__(self, + dim: Sequence[int], + angles: torch.Tensor, + shifts: torch.Tensor, + flattener: operators.SpectraFlattener, + ctfs: Optional[torch.Tensor] = None, + weights: Optional[torch.Tensor] = None, + norm: Optional[str] = None, + interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + device: Optional[torch.device] = None ) -> None: + + # Transforms + self.angles = angles + self.shifts = shifts + self.shift_filters = _compute_shift_filters(-shifts, flattener, dim, device) # Invert shifts + + # Operations + self.fourier = operators.FourierTransformer2D() + self.flattener = flattener + self.ctfs = ctfs + self.weights = weights + self.interpolation = interpolation + self.norm = norm + + # Device + self.device = device + + def __call__(self, + images: Iterable[torch.Tensor]) -> InPlaneTransformBatch: + + if self.norm: + raise NotImplementedError('Normalization is not implemented') + + indices = None + rotated_images = None + rotated_fourier_transforms = None + rotated_bands = None + ctf_bands = None + shifted_bands = None + + start = 0 + for batch in images: + if self.device is not None: + batch = batch.to(self.device, non_blocking=True) + + end = start + len(batch) + indices = torch.arange(start=start, end=end, out=indices) + + for angle in self.angles: + rotated_images = F.rotate( + batch, + angle=float(angle), + interpolation=self.interpolation + ) + + rotated_fourier_transforms = self.fourier( + rotated_images, + out=rotated_fourier_transforms + ) + + rotated_bands = self.flattener( + rotated_fourier_transforms, + out=rotated_bands + ) + + if self.weights is not None: + rotated_bands *= self.weights + + ctfs = self.ctfs if self.ctfs is not None else itertools.repeat(None, times=1) + for ctf in ctfs: + if ctf is not None: + # Apply the CTF when provided + ctf_bands = torch.mul( + rotated_bands, + ctf, + out=ctf_bands + ) + else: + # No CTF + ctf_bands = rotated_bands + + for shift, shift_filter in zip(self.shifts, self.shift_filters): + shifted_bands = torch.mul( + ctf_bands, + shift_filter, + out=shifted_bands + ) + + yield InPlaneTransformBatch( + indices=indices, + vectors=math.flat_view_as_real(shifted_bands), + angle=float(angle), + shift=shift + ) + + # Advance the counter for the next iteration + start = end + \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/InPlaneTransformBatch.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/InPlaneTransformBatch.py new file mode 100644 index 000000000..42f1ddc06 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/InPlaneTransformBatch.py @@ -0,0 +1,30 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import NamedTuple +import torch + +class InPlaneTransformBatch(NamedTuple): + indices: torch.IntTensor + vectors: torch.Tensor + shift: torch.Tensor + angle: float diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/__init__.py index 5e1d064bb..99354ae2a 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/__init__.py @@ -20,4 +20,12 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ +from .align import align +from .train import train +from .populate import populate +from .generate_alignment_metadata import generate_alignment_metadata + +from .FourierInPlaneTransformAugmenter import FourierInPlaneTransformAugmenter +from .FourierInPlaneTransformGenerator import FourierInPlaneTransformGenerator +from .FourierInPlaneTransformCorrector import FourierInPlaneTransformCorrector from .InPlaneTransformCorrector import InPlaneTransformCorrector \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/align.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/align.py new file mode 100644 index 000000000..1d55c3e1e --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/align.py @@ -0,0 +1,52 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Iterable +import torch + +from .. import search + + +def align(db: search.Database, + dataset: Iterable[torch.Tensor], + k: int ) -> search.SearchResult: + + database_device = db.get_input_device() + + index_vectors = [] + distance_vectors = [] + + for vectors in dataset: + # Search them + s = db.search(vectors.to(database_device), k=k) + + # Add them to the result + index_vectors.append(s.indices.cpu()) + distance_vectors.append(s.distances.cpu()) + + # Concatenate all result vectors + return search.SearchResult( + indices=torch.cat(index_vectors, axis=0), + distances=torch.cat(distance_vectors, axis=0) + ) + + \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/generate_alignment_metadata.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/generate_alignment_metadata.py new file mode 100644 index 000000000..26144be22 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/generate_alignment_metadata.py @@ -0,0 +1,166 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional, Sequence + +import pandas as pd +import numpy as np + +from .. import metadata as md +from .. import search + +def _ensemble_alignment_md(reference_md: pd.DataFrame, + projection_md: pd.DataFrame, + match_distances: np.ndarray, + match_indices: np.ndarray, + reference_columns: Sequence[int], + index: Optional[np.ndarray] = None, + local_transform_md: Optional[pd.DataFrame] = None ) -> pd.DataFrame: + + result = pd.DataFrame(match_distances, columns=[md.COST], index=index) + + # Left-join the projection metadata to the result + result = result.join(projection_md, on=match_indices) + + # Left-join the reference metadata to the result + result = result.join(reference_md[reference_columns], on=md.REF) + + # Drop the indexing columns + result.drop(columns=md.REF, inplace=True) + + # Accumulate transforms + if local_transform_md is not None: + result[local_transform_md.columns] += local_transform_md + + # Make psi in [-180, 180] + if md.ANGLE_PSI in local_transform_md.columns: + result.loc[result[md.ANGLE_PSI] > +180.0, md.ANGLE_PSI] -= 360.0 + result.loc[result[md.ANGLE_PSI] < -180.0, md.ANGLE_PSI] += 360.0 + + return result + +def _update_alignment_metadata(output_md: pd.DataFrame, + reference_md: pd.DataFrame, + projection_md: pd.DataFrame, + match_distances: np.ndarray, + match_indices: np.ndarray, + reference_columns: Sequence[int], + local_transform_md: Optional[pd.DataFrame] = None ) -> pd.DataFrame: + # Select the rows to be updated + selection = match_distances < output_md[md.COST].to_numpy() + + if local_transform_md is not None: + local_transform_md = local_transform_md[selection] + + # Do an alignment for the selected rows + alignment_md = _ensemble_alignment_md( + reference_md=reference_md, + projection_md=projection_md, + match_distances=match_distances[selection], + match_indices=match_indices[selection], + reference_columns=reference_columns, + index=np.nonzero(selection)[0], + local_transform_md=local_transform_md + ) + + # Update output + output_md.loc[selection, alignment_md.columns] = alignment_md + + return output_md + +def _create_alignment_metadata(experimental_md: pd.DataFrame, + reference_md: pd.DataFrame, + projection_md: pd.DataFrame, + match_distances: np.ndarray, + match_indices: np.ndarray, + reference_columns: Sequence[int], + local_transform_md: Optional[pd.DataFrame] = None ) -> pd.DataFrame: + + # Use the first match + alignment_md = _ensemble_alignment_md( + reference_md=reference_md, + projection_md=projection_md, + match_distances=match_distances, + match_indices=match_indices, + reference_columns=reference_columns, + local_transform_md=local_transform_md + ) + + # Add the alignment consensus to the output + output_md = experimental_md.drop(columns=alignment_md.columns, errors='ignore') + output_md = output_md.join(alignment_md) + + # Reorder columns for more convenient reading + output_md = output_md.reindex( + columns=experimental_md.columns.union(output_md.columns, sort=False) + ) + + return output_md + +def generate_alignment_metadata(experimental_md: pd.DataFrame, + reference_md: pd.DataFrame, + projection_md: pd.DataFrame, + matches: search.SearchResult, + reference_columns: Sequence[int], + local_transform_md: Optional[pd.DataFrame] = None, + output_md: Optional[pd.DataFrame] = None) -> pd.DataFrame: + + # Rename the reference image column to make it compatible + # with the resulting MD. (No duplicated IMAGE column) + reference_md = reference_md.rename(columns={ + md.IMAGE: md.REFERENCE_IMAGE, + }) + + # Flatten the kNN results into a single dim and provide + # them as a numpy array. Concatenate columns to match + # the concatenation of experimental md + k = matches.indices.shape[1] + match_distances = matches.distances.numpy().flatten() + match_indices = matches.indices.numpy().flatten() + + # Update or generate depending on wether the output is provided + if output_md is None: + # Repeat each row of the experimental MD k times + experimental_md = experimental_md.loc[experimental_md.index.repeat(k)].reset_index(drop=True) + + output_md = _create_alignment_metadata( + experimental_md=experimental_md, + reference_md=reference_md, + projection_md=projection_md, + match_distances=match_distances, + match_indices=match_indices, + reference_columns=reference_columns, + local_transform_md=local_transform_md + ) + else: + output_md = _update_alignment_metadata( + output_md=output_md, + reference_md=reference_md, + projection_md=projection_md, + match_distances=match_distances, + match_indices=match_indices, + reference_columns=reference_columns, + local_transform_md=local_transform_md + ) + + assert(output_md is not None) + return output_md diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/populate.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/populate.py new file mode 100644 index 000000000..7b355d085 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/populate.py @@ -0,0 +1,71 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Iterable +import pandas as pd + +from .. import search +from .. import metadata as md +from .InPlaneTransformBatch import InPlaneTransformBatch + +def populate(db: search.Database, + dataset: Iterable[InPlaneTransformBatch] ) -> pd.DataFrame: + + # Empty the database + db.reset() + + database_device = db.get_input_device() + + # Create arrays for appending MD + reference_indices = [] + angles = [] + shifts_x = [] + shifts_y = [] + + # Add elements + for batch in dataset: + # Populate the database + db.add(batch.vectors.to(database_device)) + + # Fill the metadata + reference_indices += batch.indices.tolist() + angles += [batch.angle] * len(batch.indices) + shifts_x += [float(batch.shift[0])] * len(batch.indices) + shifts_y += [float(batch.shift[1])] * len(batch.indices) + + # Create the output md + COLUMNS = [ + md.REF, + md.ANGLE_PSI, + md.SHIFT_X, + md.SHIFT_Y + ] + assert(len(reference_indices) == len(angles)) + assert(len(reference_indices) == len(shifts_x)) + assert(len(reference_indices) == len(shifts_y)) + + result = pd.DataFrame( + data=zip(reference_indices, angles, shifts_x, shifts_y), + columns=COLUMNS + ) + + return result \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/train.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/train.py new file mode 100644 index 000000000..d594c0dcf --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/train.py @@ -0,0 +1,45 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Iterable +import torch + +from .. import search + + +def train(db: search.Database, + dataset: Iterable[torch.Tensor], + scratch: torch.Tensor ): + + # Write + start = 0 + for vectors in dataset: + end = start + len(vectors) + + # Write + scratch[start:end,:] = vectors.to(scratch.device, non_blocking=True) + + # Setup next iteration + start = end + + # Train the database + db.train(scratch[:start]) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/dct/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/dct/__init__.py new file mode 100644 index 000000000..126191a2d --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/dct/__init__.py @@ -0,0 +1,25 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from .basis import dct_ii_basis, dct_iii_basis +from .dct import bases_generator, dct, idct +from .project import project, project_nd \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/dct/basis.py b/src/xmipp/libraries/py_xmipp/swiftalign/dct/basis.py new file mode 100644 index 000000000..5f58a71a9 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/dct/basis.py @@ -0,0 +1,60 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +import torch +import math + +def _get_nk(N: int): + n = torch.arange(N) + k = n.view(N, 1) + return n, k + +def dct_ii_basis(N: int, norm: bool = True) -> torch.Tensor: + n, k = _get_nk(N) + + result = (n + 0.5) * k + result *= torch.pi / N + result = torch.cos(result, out=result) + + # Normalize + if norm: + result *= math.sqrt(1/N) + result[1:,:] *= math.sqrt(2) + + return result + +def dct_iii_basis(N: int, norm: bool = True) -> torch.Tensor: + n, k = _get_nk(N) + + # TODO avoid computing result[:,0] twice + result = (k + 0.5) * n + result *= torch.pi / N + result = torch.cos(result, out=result) + + if norm: + result[:,0] = 1 / math.sqrt(2) + result *= math.sqrt(2/N) + else: + result[:,0] = 0.5 + + + return result \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/dct/dct.py b/src/xmipp/libraries/py_xmipp/swiftalign/dct/dct.py new file mode 100644 index 000000000..645be50d5 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/dct/dct.py @@ -0,0 +1,48 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional, Sequence, Iterable, Callable +import torch + +from .basis import dct_ii_basis, dct_iii_basis +from .project import project_nd + +def bases_generator(shape: Sequence[int], + dims: Iterable[int], + func: Callable[[int], torch.Tensor]) -> Iterable[torch.Tensor]: + sizes = map(shape.__getitem__, dims) + bases = map(func, sizes) + return bases + +def dct(x: torch.Tensor, + dims: Iterable[int], + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + + bases = bases_generator(x.shape, dims, dct_ii_basis) + return project_nd(x, dims, bases, out=out) + +def idct(x: torch.Tensor, + dims: Iterable[int], + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + + bases = bases_generator(x.shape, dims, dct_iii_basis) + return project_nd(x, dims, bases, out=out) \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/dct/project.py b/src/xmipp/libraries/py_xmipp/swiftalign/dct/project.py new file mode 100644 index 000000000..ac428e6d1 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/dct/project.py @@ -0,0 +1,70 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional, Iterable +import torch + +def project(x: torch.Tensor, + dim: int, + basis: torch.Tensor, + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + + if out is x: + raise Exception('Aliasing between x and out is not supported') + + def t(x: torch.Tensor) -> torch.Tensor: + return torch.transpose(x, dim, -1) + + # Transpose the input to have dim + # on the last axis + x = t(x) + if out is not None: + out = t(out) + + # Perform the projection + out = torch.matmul(basis, x, out=out) + + # Undo the transposition + out = t(out) # Return dim to its place + + return out + +def project_nd(x: torch.Tensor, + dims: Iterable[int], + bases: Iterable[torch.Tensor], + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + + temp = None + for i, (dim, basis) in enumerate(zip(dims, bases)): + if i == 0: + # First iteration + out = project(x, dim, basis, out=out) + else: + assert(out is not None) + if temp is None: + temp = out.clone() + else: + temp[...] = out + + out = project(temp, dim, basis, out=out) + + return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py index e0fdbd48d..47a145f50 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py @@ -20,5 +20,7 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ +from .remove_symmetic_half import remove_symmetric_half from .rfftnfreq import rfftnfreq -from .zero_pad import zero_pad \ No newline at end of file +from .time_shift_filter import time_shift_filter +from .zero_pad import zero_pad diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/remove_symmetic_half.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/remove_symmetic_half.py new file mode 100644 index 000000000..fadce4819 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/remove_symmetic_half.py @@ -0,0 +1,36 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +import torch + +def remove_symmetric_half(input: torch.Tensor) -> torch.Tensor: + """Removes the conjugate symmetry from a multidimensional Fourier transform + + Args: + input (torch.Tensor): Input tensor + + Returns: + torch.Tensor: Input without symmetry + """ + x_size = input.shape[-1] + half_x_size = x_size // 2 + 1 + return input[...,:half_x_size] diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/rfftnfreq.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/rfftnfreq.py index 51330e22a..f101d6d02 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/rfftnfreq.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/rfftnfreq.py @@ -52,5 +52,4 @@ def rfftfreq(dim: int) -> torch.Tensor: mesh = torch.meshgrid(*reversed(axis_freq), indexing='xy') return torch.stack(mesh) - - + \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/time_shift_filter.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/time_shift_filter.py new file mode 100644 index 000000000..6064b930f --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/time_shift_filter.py @@ -0,0 +1,45 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch + +def time_shift_filter(shift: torch.Tensor, + freq: torch.Tensor, + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + """Generates a multidimensional shift filter in Fourier space + + Args: + shift (torch.Tensor): Shift in samples. (B, n) where shape, where B is the batch size and n is the dimensions + freq (torch.Tensor): Frequency grid in radians. (B, dn, ... dy, dx) + out (Optional[torch.Tensor], optional): Preallocated tensor. Defaults to None. + + Returns: + torch.Tensor: _description_ + """ + + # Fourier time shift theorem: + # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Some_discrete_Fourier_transform_pairs + angles = -torch.matmul(shift, freq) + gain = torch.tensor(1.0).to(angles) # TODO try to avoid using this + out = torch.polar(gain, angles, out=out) + return out diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/__init__.py new file mode 100644 index 000000000..ed70528cd --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/math/__init__.py @@ -0,0 +1,26 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from .complex_normalize import complex_normalize +from .l2_normalize import l2_normalize +from .mu_sigma_normalize import mu_sigma_normalize +from .flat_view_as_real import flat_view_as_real \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/complex_normalize.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/complex_normalize.py new file mode 100644 index 000000000..8567e32d9 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/math/complex_normalize.py @@ -0,0 +1,35 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch + +def complex_normalize(data: torch.Tensor, + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + + if out is data: + out /= torch.abs(out) + + else: + raise NotImplementedError('Only implemented for out=data') + + return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/flat_view_as_real.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/flat_view_as_real.py new file mode 100644 index 000000000..f78e5bfb5 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/math/flat_view_as_real.py @@ -0,0 +1,28 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +import torch + +def flat_view_as_real(input: torch.Tensor) -> torch.Tensor: + real = torch.view_as_real(input) + flat = torch.flatten(real, start_dim=-2, end_dim=-1) + return flat \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/l2_normalize.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/l2_normalize.py new file mode 100644 index 000000000..ae94862c4 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/math/l2_normalize.py @@ -0,0 +1,36 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Union, Sequence, Optional +import torch + +def l2_normalize(data: torch.Tensor, + dim: Union[None, int, Sequence[int]], + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + + if out is data: + out /= torch.norm(out, dim=dim, keepdim=True) + + else: + raise NotImplementedError('Only implemented for out=data') + + return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/mu_sigma_normalize.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/mu_sigma_normalize.py new file mode 100644 index 000000000..870777d81 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/math/mu_sigma_normalize.py @@ -0,0 +1,38 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Union, Sequence, Optional +import torch + +def mu_sigma_normalize( data: torch.Tensor, + dim: Union[None, int, Sequence[int]], + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + + if out is data: + std, mean = torch.std_mean(out, dim=dim, keepdim=True) + out -= mean + out /= std + + else: + raise NotImplementedError('Only implemented for out=data') + + return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctLowPassFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctLowPassFlattener.py new file mode 100644 index 000000000..a898c1c28 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctLowPassFlattener.py @@ -0,0 +1,58 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional, Sequence +import torch + +from .SpectraFlattener import SpectraFlattener + +class DctLowPassFlattener(SpectraFlattener): + def __init__( self, + dim: int, + cutoff: float, + exclude_dc: bool = True, + padded_length: Optional[int] = None, + device: Optional[torch.device] = None ): + SpectraFlattener.__init__( + self, + self._compute_mask(dim, cutoff, exclude_dc), + padded_length=padded_length, + device=device + ) + + def _compute_mask( self, + dim: int, + cutoff: float, + exclude_dc: bool ) -> torch.Tensor: + + # Compute the frequency grid + freq_x = torch.linspace(start=0, end=0.5, steps=dim) + freq_y = freq_x[...,None] + freq2 = freq_x**2 + freq_y**2 + + # Compute the mask + cutoff2 = cutoff ** 2 + mask = freq2.less_equal(cutoff2) + if exclude_dc: + mask[0, 0] = False + + return mask diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctTransformer2D.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctTransformer2D.py new file mode 100644 index 000000000..3024c0a9f --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctTransformer2D.py @@ -0,0 +1,47 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch + +from ..dct import dct_ii_basis, project_nd + +from .Transformer2D import Transformer2D + +class DctTransformer2D(Transformer2D): + DIMS = (-1, -2) # Last two dimensions + + def __init__(self, dim: int, device: Optional[torch.device] = None) -> None: + self._bases = (dct_ii_basis(dim).to(device), )*len(self.DIMS) + + def __call__( self, + input: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + + # To avoid warnings + if out is not None: + out.resize_(0) + + return project_nd(input, dims=self.DIMS, bases=self._bases, out=out) + + def has_complex_output(self) -> bool: + return False \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierLowPassFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierLowPassFlattener.py new file mode 100644 index 000000000..e978b82f0 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierLowPassFlattener.py @@ -0,0 +1,62 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional, Sequence +import torch + +from .SpectraFlattener import SpectraFlattener +from ..fourier import rfftnfreq + +class FourierLowPassFlattener(SpectraFlattener): + def __init__( self, + dim: Sequence[int], + cutoff: float, + exclude_dc: bool = True, + padded_length: Optional[int] = None, + device: Optional[torch.device] = None ): + SpectraFlattener.__init__( + self, + self._compute_mask(dim, cutoff, exclude_dc), + padded_length=padded_length, + device=device + ) + + def _compute_mask( self, + dim: Sequence[int], + cutoff: float, + exclude_dc: bool ) -> torch.Tensor: + + # Compute the frequency grid + frequency_grid = rfftnfreq(dim) + frequencies2 = torch.sum(frequency_grid**2, dim=0) + + # Compute the mask + cutoff2 = cutoff ** 2 + mask = frequencies2.less_equal(cutoff2) + if exclude_dc: + # Remove symmetric coefficients and DC + mask[:(dim[-2]//2),0] = False + else: + # Remove symmetric coefficients + mask[1:(dim[-2]//2),0] = False + + return mask diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierShiftFilter.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierShiftFilter.py new file mode 100644 index 000000000..f9a02ccfd --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierShiftFilter.py @@ -0,0 +1,64 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch + +from .SpectraFlattener import SpectraFlattener + +class FourierShiftFilter: + def __init__(self, + dim: int, + shifts: torch.Tensor, + flattener: SpectraFlattener, + device: Optional[torch.device] = None): + + freq = self._get_freq(dim, flattener, device=device) + angles = -torch.mm(shifts.to(device), freq) + ONE = torch.tensor(1, dtype=angles.dtype, device=device) + self._filters = torch.polar(ONE, angles) + self._shifts = shifts + + def __call__( self, + input: torch.Tensor, + index: int, + out: Optional[torch.Tensor] = None ): + filter = self._filters[index,:] + out = torch.mul(input, filter, out=out) + return out + + def get_count(self) -> int: + return self._filters.shape[0] + + def get_shift(self, index: int) -> torch.Tensor: + return self._shifts[index] + + def _get_freq( self, + dim: int, + flattener: SpectraFlattener, + device: Optional[torch.device] = None) -> torch.Tensor: + + d = 0.5 / (dim*torch.pi) + freq_x = torch.fft.rfftfreq(dim, d=d, device=device) + freq_y = torch.fft.fftfreq(dim, d=d, device=device) + grid = torch.stack(torch.meshgrid(freq_x, freq_y, indexing='xy')) + return flattener(grid) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierTransformer2D.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierTransformer2D.py new file mode 100644 index 000000000..b0dd0e774 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierTransformer2D.py @@ -0,0 +1,40 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch + +from .Transformer2D import Transformer2D + +class FourierTransformer2D(Transformer2D): + def __call__( self, + input: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + + # To avoid warnings + if out is not None: + out.resize_(0) + + return torch.fft.rfft2(input, out=out) + + def has_complex_output(self) -> bool: + return True \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageRotator.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageRotator.py new file mode 100644 index 000000000..c4aeafeac --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageRotator.py @@ -0,0 +1,51 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch +import torchvision.transforms as T + +class ImageRotator: + def __init__( self, + angles: torch.Tensor, + device: Optional[torch.device] = None ): + self._angles = angles + + def __call__( self, + input: torch.Tensor, + index: int, + out: Optional[torch.Tensor] ) -> torch.Tensor: + + # TODO use the matrix + out = T.functional.rotate( + input, + self.get_angle(index), + T.InterpolationMode.BILINEAR, + ) + return out + + def get_count(self) -> int: + return len(self._angles) + + def get_angle(self, index: int) -> float: + return float(self._angles[index]) + \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageShifter.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageShifter.py new file mode 100644 index 000000000..d0205eed8 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageShifter.py @@ -0,0 +1,53 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch + +class ImageShifter: + def __init__(self, + shifts: torch.Tensor, + dim: int, + device: Optional[torch.device] = None): + + self._shifts = shifts + self._shifts_in_px = (shifts*dim).to(dtype=int) + + def __call__( self, + input: torch.Tensor, + index: int, + out: Optional[torch.Tensor] = None ): + + out = torch.roll( + input, + shifts=self._shifts_in_px[index].tolist(), + dims=(-1, -2) + ) + + return out + + def get_count(self) -> int: + return self._shifts.shape[0] + + def get_shift(self, index: int) -> torch.Tensor: + return self._shifts[index] + \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageSpectraFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageSpectraFlattener.py new file mode 100644 index 000000000..cbd3ec79b --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageSpectraFlattener.py @@ -0,0 +1,74 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional + +import torch + +from .SpectraFlattener import SpectraFlattener +from .Transformer2D import Transformer2D +from .Weighter import Weighter + +from .. import utils + +class ImageSpectraFlattener: + def __init__(self, + transform: Transformer2D, + flattener: SpectraFlattener, + weighter: Optional[Weighter] = None, + norm: Optional[str] = None ) -> None: + + self._transform = transform + self._flattener = flattener + self._weighter = weighter + self._norm = norm + + self._transformed = None + + def __call__(self, + batch: torch.Tensor, + out: Optional[torch.Tensor] = None ) -> torch.Tensor: + # Normalize image if requested + if self._norm == 'image': + batch = batch.clone() + utils.normalize(batch, dim=(-2, -1)) + + # Compute the fourier transform of the images and flatten and weighten it + self._transformed = self._transform(batch, out=self._transformed) + self._transformed_flat = self._flattener(self._transformed, out=self._transformed_flat) + + # Apply the weights + if self._weighter is not None: + self._transformed_flat = self._weighter(self._transformed_flat, out=self._transformed_flat) + + # Normalize complex numbers if requested + if self._norm == 'complex': + utils.complex_normalize(self._transformed_flat) + + # Elaborate the reference vectors + out = utils.flat_view_as_real(self._transformed_flat) + + # Normalize reference vectors if requested + if self._norm == 'vector': + utils.l2_normalize(out, dim=-1) + + return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/SpectraFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/SpectraFlattener.py new file mode 100644 index 000000000..c2d771884 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/SpectraFlattener.py @@ -0,0 +1,33 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch + +from .MaskFlattener import MaskFlattener + +class SpectraFlattener(MaskFlattener): + def __init__( self, + mask: torch.Tensor, + padded_length: Optional[int] = None, + device: Optional[torch.device] = None): + super().__init__(mask, padded_length, device) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/Transformer2D.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/Transformer2D.py new file mode 100644 index 000000000..bc14a18b6 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/Transformer2D.py @@ -0,0 +1,33 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import torch + +class Transformer2D: + def __call__(self, + input: torch.Tensor, + out: Optional[torch.Tensor] = None): + pass + + def has_complex_output(self) -> bool: + pass \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/__init__.py index edd37da7b..c43454769 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/operators/__init__.py @@ -20,4 +20,13 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ -from .MaskFlattener import MaskFlattener \ No newline at end of file +from .Transformer2D import Transformer2D +from .FourierTransformer2D import FourierTransformer2D +from .DctTransformer2D import DctTransformer2D +from .MaskFlattener import MaskFlattener +from .SpectraFlattener import SpectraFlattener +from .FourierLowPassFlattener import FourierLowPassFlattener +from .DctLowPassFlattener import DctLowPassFlattener +from .ImageRotator import ImageRotator +from .ImageShifter import ImageShifter +from .FourierShiftFilter import FourierShiftFilter diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/search/Database.py b/src/xmipp/libraries/py_xmipp/swiftalign/search/Database.py new file mode 100644 index 000000000..f87ed3e38 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/search/Database.py @@ -0,0 +1,81 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import NamedTuple +import torch + +class SearchResult(NamedTuple): + indices: torch.IntTensor + distances: torch.Tensor + +class Database: + def __init__(self) -> None: + pass + + def train(self, vectors: torch.Tensor) -> None: + pass + + def add(self, vectors: torch.Tensor) -> None: + pass + + def finalize(self): + pass + + def reset(self): + pass + + def search(self, vectors: torch.Tensor, k: int) -> SearchResult: + pass + + def read(self, path: str): + pass + + def write(self, path: str): + pass + + def to_device(self, device: torch.device): + pass + + def is_trained(self) -> bool: + pass + + def is_populated(self) -> bool: + return self.get_item_count() > 0 + + def is_finalized(self) -> bool: + pass + + def get_dim(self) -> int: + pass + + def get_item_count(self) -> int: + pass + + def get_input_device(self) -> torch.device: + pass + + def _check_input(self, x: torch.Tensor): + if len(x.shape) != 2: + raise RuntimeError('Input should have 2 dimensions (batch and vector)') + + if x.shape[-1] != self.get_dim(): + raise RuntimeError('Input vectors have incorrect size') \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/search/Faiss.py b/src/xmipp/libraries/py_xmipp/swiftalign/search/Faiss.py new file mode 100644 index 000000000..c3c12f90f --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/search/Faiss.py @@ -0,0 +1,145 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional +import math +import torch +import faiss +import faiss.contrib.torch_utils + +from .Database import Database, SearchResult + +def opq_ifv_pq_recipe(dim: int, size: int = int(3e6), c: float = 16, norm=False): + """ + Values selected using: + https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index + """ + # Determine the parameters + PQ_BLOCK_SIZE = 4 # Coefficients per PQ partition + PQ_MAX_BYTES_PER_VECTOR = 48 # For single precision + PQ_MAX_VECTOR_SIZE = PQ_BLOCK_SIZE * PQ_MAX_BYTES_PER_VECTOR + d = min(PQ_MAX_VECTOR_SIZE, dim//PQ_BLOCK_SIZE*PQ_BLOCK_SIZE) + m = d // PQ_BLOCK_SIZE # Bytes per vector. 48 is the max for GPU and single precision + k = c * math.sqrt(size) # Number of clusters + k = 2 ** math.ceil(math.log2(k)) # Use a power of 2 + + # Elaborate the recipe for the factory + opq = f'OPQ{m}_{d}' # Only supported on CPU + ifv = f'IVF{k}' # In order to support GPU do not use HNSW32 + pq = f'PQ{m}' + recipe = (opq, ifv, pq) + + if norm: + l2norm = 'L2norm' + recipe = (l2norm, ) + recipe + + return ','.join(recipe) + +class FaissDatabase(Database): + def __init__(self, + dim: int = 0, + recipe: Optional[str] = None ) -> None: + + index: Optional[faiss.Index] = None + if recipe and dim: + index = faiss.index_factory(dim, recipe) + + self._index = index + + def train(self, vectors: torch.Tensor) -> None: + self._check_input(vectors) + self._sync(vectors) + self._index.train(vectors) + + def add(self, vectors: torch.Tensor) -> None: + self._check_input(vectors) + self._sync(vectors) + self._index.add(vectors) + + def reset(self): + self._index.reset() + + def search(self, vectors: torch.Tensor, k: int) -> SearchResult: + self._check_input(vectors) + self._sync(vectors) + distances, indices = self._index.search(vectors, k) + return SearchResult(indices=indices, distances=distances) + + def read(self, path: str): + self._index = faiss.read_index(path) + + def write(self, path: str): + faiss.write_index(self._index, path) + + def to_device(self, + device: torch.device, + use_f16: bool = False, + reserve_vecs: int = 0, + use_precomputed = False ): + if device.type == 'cuda': + resources = faiss.StandardGpuResources() + resources.setDefaultNullStreamAllDevices() # To interop with torch + co = faiss.GpuClonerOptions() + co.useFloat16 = use_f16 + co.useFloat16CoarseQuantizer = use_f16 + co.usePrecomputed = use_precomputed + co.reserveVecs = reserve_vecs + + self._index = faiss.index_cpu_to_gpu( + resources, + device.index, + self._index, + co + ) + + elif device.type == 'cpu': + self._index = faiss.index_gpu_to_cpu(self._index) + + else: + raise ValueError('Input device must be CPU or CUDA') + + def is_trained(self) -> bool: + return self._index.is_trained + + def is_finalized(self) -> bool: + return True + + def get_dim(self) -> int: + return self._index.d + + def get_item_count(self) -> int: + return self._index.ntotal + + def get_input_device(self) -> torch.device: + return torch.device('cpu') # TODO determine + + def set_metric_type(self, metric_type: int): + self._index.metric_type = metric_type + + def get_metric_type(self) -> int: + return self._index.metric_type + + def _sync(self, vectors: torch.Tensor): + if vectors.device.type == 'cuda': + stream = torch.cuda.current_stream(vectors.device) + event = stream.record_event() + raise NotImplementedError('We should sync here') \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/search/MedianHash.py b/src/xmipp/libraries/py_xmipp/swiftalign/search/MedianHash.py new file mode 100644 index 000000000..1acbc8a97 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/search/MedianHash.py @@ -0,0 +1,139 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional, List +import torch + +from .Database import Database, SearchResult + +class MedianHashDatabase(Database): + def __init__(self, + dim: int = 0) -> None: + + self._dim = dim + self._median: Optional[torch.Tensor] = None + self._hashes: List[torch.Tensor] = [] + + def train(self, vectors: torch.Tensor) -> None: + self._check_input(vectors) + self._median = torch.median(vectors, dim=0).values + + def add(self, vectors: torch.Tensor) -> None: + self._check_input(vectors) + self._hashes.append(self.__compute_hash(vectors)) + + def reset(self): + self._hashes.clear() + + def search(self, vectors: torch.Tensor, k: int) -> SearchResult: + if k != 1: + raise NotImplementedError('KNN has been only implemented for k=1') + + self._check_input(vectors) + + # Compute the signature of the search vectors + search_hashes = self.__compute_hash(vectors) + + result = None + + xors = None + pop_counts = None + best_candidate = None + mask = None + base_index = 0 + for reference_hashes in self._hashes: + # Perform a XOR for all possible pairs + xors = torch.logical_xor( + search_hashes[:,None,:], + reference_hashes[None,:,:], + out=xors + ) + + # Do pop-count + pop_counts = torch.count_nonzero(xors, dim=-1) # TODO add out=count + + # Find the best candidates + best_candidate = torch.min(pop_counts, dim=-1, out=best_candidate) + + # Evaluate new candidates + if result is None: + result = SearchResult( + indices=best_candidate.indices.clone(), + distances=best_candidate.values.clone() + ) + else: + mask = torch.less(best_candidate.values, result.distances, out=mask) + result.indices[mask] = best_candidate.indices[mask] + base_index + result.distances[mask] = best_candidate.values[mask] + + # Update base index for next batch + base_index += len(reference_hashes) + + # Add a dimension in the end + return SearchResult( + indices=result.indices[...,None], + distances=result.distances[...,None] + ) + + def read(self, path: str): + obj = torch.load(path) + self._dim = obj['dim'] + self._median = obj['median'] + self._hashes = obj['hashes'] + + def write(self, path: str): + obj = { + 'dim': self._dim, + 'median': self._median, + 'hashes': self._hashes + } + torch.save(obj, path) + + def to_device(self, device: torch.device): + def func(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if x is None: + return None + else: + return x.to(device=device) + + self._median = func(self._median) + self._hashes = list(map(func, self._hashes)) + + def is_trained(self) -> bool: + return self._median is not None + + def is_finalized(self) -> bool: + return True + + def get_dim(self) -> int: + return self._dim + + def get_item_count(self) -> int: + return sum(map(len, self._hashes)) + + def get_input_device(self) -> torch.device: + return self._median.device + + def __compute_hash(self, + vectors: torch.Tensor, + out: Optional[torch.BoolTensor] = None ) -> torch.BoolTensor: + return torch.greater(self._median, vectors, out=out) \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/search/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/search/__init__.py new file mode 100644 index 000000000..c5ba86ba2 --- /dev/null +++ b/src/xmipp/libraries/py_xmipp/swiftalign/search/__init__.py @@ -0,0 +1,25 @@ +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from .Database import Database, SearchResult +from .Faiss import FaissDatabase, opq_ifv_pq_recipe +from .MedianHash import MedianHashDatabase \ No newline at end of file From d6a71698bc9c00ded2f2a59ac1df86f8f16a38a6 Mon Sep 17 00:00:00 2001 From: Oier Lauzirika Zarrabeitia Date: Thu, 16 May 2024 07:43:41 +0000 Subject: [PATCH 15/15] Added removed programs --- .../swiftalign_query/swiftalign_query.py | 320 ++++++++++++++++++ .../swiftalign_train/swiftalign_train.py | 203 +++++++++++ 2 files changed, 523 insertions(+) create mode 100755 src/xmipp/applications/scripts/swiftalign_query/swiftalign_query.py create mode 100755 src/xmipp/applications/scripts/swiftalign_train/swiftalign_train.py diff --git a/src/xmipp/applications/scripts/swiftalign_query/swiftalign_query.py b/src/xmipp/applications/scripts/swiftalign_query/swiftalign_query.py new file mode 100755 index 000000000..966d95f1c --- /dev/null +++ b/src/xmipp/applications/scripts/swiftalign_query/swiftalign_query.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python + +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional, Tuple, Sequence +import torch +import argparse +import itertools +import time +import pandas as pd + +import xmippPyModules.swiftalign.image as image +import xmippPyModules.swiftalign.search as search +import xmippPyModules.swiftalign.alignment as alignment +import xmippPyModules.swiftalign.operators as operators +import xmippPyModules.swiftalign.fourier as fourier +import xmippPyModules.swiftalign.metadata as md + +def _dataframe_batch_generator(df: pd.DataFrame, batch_size: int) -> pd.DataFrame: + for i in range(0, len(df), batch_size): + start = i + end = start + batch_size + yield df[start:end] + +def _read_weights(path: Optional[str], + flattener: operators.SpectraFlattener, + device: Optional[torch.device] = None ) -> torch.Tensor: + weights = None + if path: + weight_image = image.read(path) + weight_image = torch.tensor(weight_image, device=device) + weight_image = fourier.remove_symmetric_half(weight_image) + weights = flattener(weight_image) + weights = torch.sqrt(weights, out=weights) + + return weights + +def _read_ctf(path: Optional[str], + flattener: operators.SpectraFlattener, + device: Optional[torch.device] = None ) -> torch.Tensor: + ctfs = None + ctf_md = None + if path: + ctf_md = md.read(path) + ctf_paths = list(map(image.parse_path, ctf_md[md.IMAGE])) + ctf_dataset = image.torch_utils.Dataset(ctf_paths) + ctf_images = torch.utils.data.default_collate([ctf_dataset[i] for i in range(len(ctf_dataset))]) + ctf_images = fourier.remove_symmetric_half(ctf_images) + ctfs = flattener(ctf_images.to(device)) + + return ctfs + +def _calculate_rotations(max_psi: float, + n_rotations: int ) -> torch.Tensor: + + result = None + if n_rotations > 1: + if max_psi >= 180: + # Consider [-180, +180) + result = torch.linspace(-180.0, +180, n_rotations+1)[:-1] + else: + # Consider [-max_psi, +max_psi] + result = torch.linspace(-max_psi, +max_psi, n_rotations) + else: + # No rotations + result = torch.full((1, ), 0.0) + + return result + + +def _calculate_shifts(max_shift: float, + n_shifts: int, + image_size: Tuple[int, int]) -> torch.Tensor: + max_shift_x = max_shift*image_size[0] + max_shift_y = max_shift*image_size[1] + shifts_x = torch.linspace(-max_shift_x, +max_shift_x, n_shifts) + shifts_y = torch.linspace(-max_shift_y, +max_shift_y, n_shifts) + shifts = torch.cartesian_prod(shifts_x, shifts_y) + return shifts + +def run(experimental_md_path: str, + reference_md_path: str, + index_path: str, + ctf_md_path: Optional[str], + weight_image_path: Optional[str], + output_md_path: str, + n_rotations : int, + n_shifts : int, + max_psi : float, + max_shift : float, + cutoff: float, + batch_size: int, + max_size: int, + norm: Optional[str], + local: bool, + local_shift: bool, + drop_na: bool, + reference_labels: Sequence[str], + k: int, + device_names: list, + use_f16: bool, + use_precomputed: bool ): + + # Devices + if device_names: + devices = list(map(torch.device, device_names)) + else: + devices = [torch.device('cpu')] + + transform_device = torch.device('cpu') + db_device = devices[0] + + # Read input files + experimental_md = md.sort_by_image_filename(md.read(experimental_md_path)) + reference_md = md.sort_by_image_filename(md.read(reference_md_path)) + image_size = md.get_image2d_size(experimental_md) + + # Read the database + db = search.FaissDatabase() + db.read(index_path) + db.to_device(db_device, use_f16=use_f16, reserve_vecs=max_size, use_precomputed=use_precomputed) + + # Create the in-plane transforms + angles = _calculate_rotations(max_psi=max_psi, n_rotations=n_rotations) + shifts = _calculate_shifts(max_shift=max_shift, n_shifts=n_shifts, image_size=image_size) + n_transform = len(angles) * len(shifts) + print(f'Performing {n_transform} transformations to each reference image') + + # Create the band flattener + flattener = operators.FourierLowPassFlattener( + dim=image_size, + cutoff=cutoff, + exclude_dc=True, + device=transform_device + ) + + # Read weights + weights = _read_weights(weight_image_path, flattener, transform_device) + + # Read CTFs + ctfs = _read_ctf(ctf_md_path, flattener, transform_device) + + # Create the transformers + reference_transformer = alignment.FourierInPlaneTransformGenerator( + dim=image_size, + angles=angles, + shifts=shifts, + flattener=flattener, + ctfs=ctfs, + weights=weights, + norm=norm, + device=transform_device + ) + experimental_transformer = alignment.FourierInPlaneTransformCorrector( + flattener=flattener, + weights=weights, + norm=norm, + device=transform_device + ) + + # Create the reference dataset + reference_paths = list(map(image.parse_path, reference_md[md.IMAGE])) + reference_dataset = image.torch_utils.Dataset(reference_paths) + experimental_paths = list(map(image.parse_path, experimental_md[md.IMAGE])) + experimental_dataset = image.torch_utils.Dataset(experimental_paths) + n_total = len(reference_dataset) * n_transform + print(f'In total we will consider {n_total} transformed references') + + # Create the loaders + pin_memory = transform_device.type == 'cuda' + reference_loader = torch.utils.data.DataLoader( + reference_dataset, + batch_size=batch_size, + pin_memory=pin_memory, + num_workers=1 + ) + reference_batch_iterator = iter(reference_transformer(reference_loader)) + + alignment_md = None + n_batches_per_iteration = max(1, max_size // min(batch_size, len(reference_dataset))) + if local: + local_columns = [md.ANGLE_PSI, md.SHIFT_X, md.SHIFT_Y] + elif local_shift: + local_columns = [md.SHIFT_X, md.SHIFT_Y] + else: + local_columns = [] + local_transform_md = experimental_md[local_columns] + populate_time = 0.0 + alignment_time = 0.0 + while True: + + print('Uploading') + start_time = time.perf_counter() + projection_md = alignment.populate( + db, + dataset=itertools.islice(reference_batch_iterator, n_batches_per_iteration) + ) + end_time = time.perf_counter() + populate_time += end_time - start_time + + + if len(projection_md) == 0: + break + + experimental_loader = torch.utils.data.DataLoader( + experimental_dataset, + batch_size=batch_size, + pin_memory=pin_memory, + num_workers=1 + ) + + print('Aligning') + start_time = time.perf_counter() + matches = alignment.align( + db, + experimental_transformer(zip(experimental_loader, _dataframe_batch_generator(local_transform_md, batch_size))), + k=k + ) + + alignment_md = alignment.generate_alignment_metadata( + experimental_md=experimental_md, + reference_md=reference_md, + projection_md=projection_md, + matches=matches, + local_transform_md=local_transform_md, + reference_columns=reference_labels, + output_md=alignment_md + ) + end_time = time.perf_counter() + alignment_time += end_time - start_time + + + print('Populate time (s): ' + str(populate_time)) + print('Alignment time (s): ' + str(alignment_time)) + print('Alignment time per particle (ms/part.): ' + str(alignment_time*1e3/len(experimental_dataset))) + + if drop_na: + alignment_md.dropna(inplace=True) + + alignment_md.sort_index(axis=0, inplace=True) + md.write(alignment_md, output_md_path) + + + +if __name__ == '__main__': + # Define the input + parser = argparse.ArgumentParser( + prog = 'Align Nearest Neighbor Training', + description = 'Align Cryo-EM images using a fast Nearest Neighbor approach') + parser.add_argument('-i', required=True) + parser.add_argument('-r', required=True) + parser.add_argument('-o', required=True) + parser.add_argument('--weights') + parser.add_argument('--index', required=True) + parser.add_argument('--ctf', type=str) + parser.add_argument('--rotations', type=int, required=True) + parser.add_argument('--shifts', type=int, required=True) + parser.add_argument('--max_shift', type=float, required=True) + parser.add_argument('--max_psi', type=float, default=180.0) + parser.add_argument('--max_frequency', type=float, required=True) + parser.add_argument('--batch', type=int, default=1024) + parser.add_argument('--norm', type=str) + parser.add_argument('--local', action='store_true') + parser.add_argument('--local_shift', action='store_true') + parser.add_argument('--dropna', action='store_true') + parser.add_argument('--reference_labels', type=str, nargs='*') + parser.add_argument('-k', type=int, default=1) + parser.add_argument('--devices', nargs='*') + parser.add_argument('--max_size', type=int, default=int(2e6)) + parser.add_argument('--fp16', action='store_true') + parser.add_argument('--use_precomputed', action='store_true') + + # Parse + args = parser.parse_args() + + # Run the program + run( + experimental_md_path = args.i, + reference_md_path = args.r, + index_path = args.index, + ctf_md_path = args.ctf, + weight_image_path = args.weights, + output_md_path = args.o, + n_rotations = args.rotations, + n_shifts = args.shifts, + max_shift = args.max_shift, + max_psi = args.max_psi, + cutoff = args.max_frequency, + batch_size = args.batch, + max_size = args.max_size, + local = args.local, + local_shift = args.local_shift, + norm = args.norm, + drop_na = args.dropna, + reference_labels = args.reference_labels, + k = args.k, + device_names = args.devices, + use_f16 = args.fp16, + use_precomputed=args.use_precomputed + ) \ No newline at end of file diff --git a/src/xmipp/applications/scripts/swiftalign_train/swiftalign_train.py b/src/xmipp/applications/scripts/swiftalign_train/swiftalign_train.py new file mode 100755 index 000000000..9b8db47b6 --- /dev/null +++ b/src/xmipp/applications/scripts/swiftalign_train/swiftalign_train.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python + +# *************************************************************************** +# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 2 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'xmipp@cnb.csic.es' +# ***************************************************************************/ + +from typing import Optional, Iterable +import argparse +import math +import torch + +import xmippPyModules.swiftalign.image as image +import xmippPyModules.swiftalign.operators as operators +import xmippPyModules.swiftalign.fourier as fourier +import xmippPyModules.swiftalign.search as search +import xmippPyModules.swiftalign.alignment as alignment +import xmippPyModules.swiftalign.metadata as md + +def _read_weights(path: Optional[str], + flattener: operators.SpectraFlattener, + device: Optional[torch.device] = None ) -> torch.Tensor: + weights = None + if path: + weight_image = image.read(path) + weight_image = torch.tensor(weight_image, device=device) + weight_image = fourier.remove_symmetric_half(weight_image) + weights = flattener(weight_image) + weights = torch.sqrt(weights, out=weights) + + return weights + +def _read_ctf(path: Optional[str], + flattener: operators.SpectraFlattener, + device: Optional[torch.device] = None ) -> torch.Tensor: + ctfs = None + ctf_md = None + if path: + ctf_md = md.read(path) + ctf_paths = list(map(image.parse_path, ctf_md[md.IMAGE])) + ctf_dataset = image.torch_utils.Dataset(ctf_paths) + ctf_images = torch.utils.data.default_collate([ctf_dataset[i] for i in range(len(ctf_dataset))]) + ctf_images = fourier.remove_symmetric_half(ctf_images) + ctfs = flattener(ctf_images.to(device)) + + return ctfs + +def run(reference_md_path: str, + index_path: str, + recipe: str, + ctf_md_path: Optional[str], + weight_image_path: Optional[str], + max_shift : float, + max_psi: float, + cutoff: float, + norm: Optional[str], + n_training: int, + n_batch: int, + device_names: list, + scratch_path: Optional[str], + use_f16: bool, + use_precomputed: bool ): + + # Devices + if device_names: + devices = list(map(torch.device, device_names)) + else: + devices = [torch.device('cpu')] + + transform_device = devices[0] + db_device = devices[0] + + # Read input files + reference_md = md.sort_by_image_filename(md.read(reference_md_path)) + image_size = md.get_image2d_size(reference_md) + + # Create the flattener + flattener = operators.FourierLowPassFlattener( + dim=image_size, + cutoff=cutoff, + exclude_dc=True, + device=transform_device + ) + + # Read weights + weights = _read_weights(weight_image_path, flattener, transform_device) + + # Read CTFs + ctfs = _read_ctf(ctf_md_path, flattener, transform_device) + + # Create the transformer + transformer = alignment.FourierInPlaneTransformAugmenter( + max_psi=max_psi, + max_shift=max_shift, + flattener=flattener, + ctfs=ctfs, + weights=weights, + norm=norm + ) + + # Create the image loader + image_paths = list(map(image.parse_path, reference_md[md.IMAGE])) + dataset = image.torch_utils.Dataset(image_paths) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=n_batch, + pin_memory=transform_device.type=='cuda', + num_workers=1 + ) + + # Create the image transformer + n_repetitions = n_training // len(dataset) + n_training = n_repetitions * len(dataset) + + # Create the DB to store the data + dim = flattener.get_length()*2 + print(f'Data dimensions: {dim}') + db = search.FaissDatabase(dim, recipe) + db.to_device(db_device, use_f16=use_f16, use_precomputed=use_precomputed) + + # Create the storage for the training set. + # This will be LARGE. Therefore provide a MMAP path + training_set_shape = (n_training, dim) + if scratch_path: + size = math.prod(training_set_shape) + storage = torch.FloatStorage.from_file(scratch_path, shared=True, size=size) + training_set = torch.FloatTensor(storage=storage) + training_set = training_set.view(training_set_shape) + else: + training_set = torch.empty(training_set_shape, device=torch.device('cpu')) + + # Run the training + uploader = map(lambda x : x.to(transform_device, non_blocking=True), loader) + alignment.train( + db, + dataset=transformer(uploader, times=n_repetitions), + scratch=training_set + ) + + # Write to disk + db.to_device(torch.device('cpu')) + db.write(index_path) + + +if __name__ == '__main__': + # Define the input + parser = argparse.ArgumentParser( + prog = 'Align Nearest Neighbor Training', + description = 'Align Cryo-EM images using a fast Nearest Neighbor approach') + parser.add_argument('-i', type=str, required=True) + parser.add_argument('-o', type=str, required=True) + parser.add_argument('--recipe', type=str, required=True) + parser.add_argument('--ctf', type=str) + parser.add_argument('--weights', type=str) + parser.add_argument('--max_shift', type=float, required=True) + parser.add_argument('--max_psi', type=float, default=180.0) + parser.add_argument('--max_frequency', type=float, required=True) + parser.add_argument('--norm', type=str) + parser.add_argument('--training', type=int, default=int(4e6)) + parser.add_argument('--batch', type=int, default=int(1024)) + parser.add_argument('--device', nargs='*') + parser.add_argument('--scratch', type=str) + parser.add_argument('--fp16', action='store_true') + parser.add_argument('--use_precomputed', action='store_true') + + # Parse + args = parser.parse_args() + + # Run the program + run( + reference_md_path = args.i, + index_path = args.o, + recipe = args.recipe, + ctf_md_path = args.ctf, + weight_image_path = args.weights, + max_shift = args.max_shift, + max_psi = args.max_psi, + cutoff = args.max_frequency, + norm = args.norm, + n_training = args.training, + n_batch = args.batch, + device_names = args.device, + scratch_path=args.scratch, + use_f16=args.fp16, + use_precomputed=args.use_precomputed + ) \ No newline at end of file