diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/__init__.py index 9842993a0..0ce46cb97 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/__init__.py @@ -20,10 +20,9 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ -from . import dct +from . import alignment +from . import classification from . import image from . import metadata -from . import operators -from . import search from . import transform from . import utils \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformAugmenter.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformAugmenter.py deleted file mode 100644 index dba34f6a7..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformAugmenter.py +++ /dev/null @@ -1,81 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Iterable, Optional -import torch -import torchvision.transforms as T - -import random - -from .. import operators -from .. import math - -class FourierInPlaneTransformAugmenter: - def __init__(self, - max_psi: float, - max_shift: float, - flattener: operators.SpectraFlattener, - ctfs: Optional[torch.Tensor] = None, - weights: Optional[torch.Tensor] = None, - norm: Optional[str] = None, - interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR ) -> None: - - # Random affine transformer - self.random_affine = T.RandomAffine( - degrees=max_psi, - translate=(max_shift, )*2, - interpolation=interpolation - ) - - # Operations - self.fourier = operators.FourierTransformer2D() - self.flattener = flattener - self.ctfs = ctfs - self.weights = weights - self.norm = norm - - def __call__(self, - images: Iterable[torch.Tensor], - times: int = 1) -> torch.Tensor: - - if self.norm: - raise NotImplementedError('Normalization is not implemented') - - images_affine = None - images_fourier_transform = None - images_band = None - - for batch in images: - for _ in range(times): - images_affine = self.random_affine(batch) - images_fourier_transform = self.fourier(images_affine, out=images_fourier_transform) - images_band = self.flattener(images_fourier_transform, out=images_band) - - if self.weights is not None: - images_band *= self.weights - - if self.ctfs is not None: - # Select a random CTF and apply it - ctf = self.ctfs[random.randrange(len(self.ctfs))] - images_band *= ctf - - yield math.flat_view_as_real(images_band) \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformCorrector.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformCorrector.py deleted file mode 100644 index 1b36d21b6..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformCorrector.py +++ /dev/null @@ -1,111 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Iterable, Optional, Tuple -import pandas as pd -import torch - -from .. import operators -from .. import transform -from .. import math -from .. import metadata as md - -class FourierInPlaneTransformCorrector: - def __init__(self, - flattener: operators.SpectraFlattener, - weights: Optional[torch.Tensor] = None, - norm: Optional[str] = None, - interpolation: str = 'bilinear', - device: Optional[torch.device] = None ) -> None: - - # Operations - self.fourier = operators.FourierTransformer2D() - self.flattener = flattener - self.weights = weights - self.interpolation = interpolation - self.norm = norm - - # Cached variables - self.rotated_images = None - self.rotated_fourier_transform = None - self.rotated_band = None - self.shifted_rotated_band = None - - # Device - self.device = device - - def __call__(self, - images: Iterable[Tuple[torch.Tensor, pd.DataFrame]]) -> torch.Tensor: - - if self.norm: - raise NotImplementedError('Normalization is not implemented') - - angles = None - transform_matrix = None - fourier_transforms = None - bands = None - - TRANSFORM_LABELS = [md.ANGLE_PSI, md.SHIFT_X, md.SHIFT_Y] - - for batch_images, batch_md in images: - if len(batch_images) != len(batch_md): - raise RuntimeError('Metadata and image batch sizes do not match') - - if self.device is not None: - batch_images = batch_images.to(self.device, non_blocking=True) - - if all(map(batch_md.columns.__contains__, TRANSFORM_LABELS)): - transformations = torch.as_tensor(batch_md[TRANSFORM_LABELS].to_numpy(), dtype=torch.float32) - angles = torch.deg2rad(transformations[:,0], out=angles) - shifts = transformations[:,1:] - centre = torch.tensor(batch_images.shape[-2:]) / 2 - - transform_matrix = transform.affine_matrix_2d( - angles=angles, - shifts=shifts, - centre=centre, - shift_first=True, - out=transform_matrix - ) - - batch_images = transform.affine_2d( - images=batch_images, - matrices=transform_matrix.to(batch_images, non_blocking=True), - interpolation=self.interpolation, - padding='zeros' - ) - - fourier_transforms = self.fourier( - batch_images, - out=fourier_transforms - ) - - bands = self.flattener( - fourier_transforms, - out=bands - ) - - if self.weights is not None: - bands *= self.weights - - yield math.flat_view_as_real(bands) - \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformGenerator.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformGenerator.py deleted file mode 100644 index 19d4605d0..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/FourierInPlaneTransformGenerator.py +++ /dev/null @@ -1,141 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Sequence, Iterable, Optional -import torch -import torchvision.transforms as T -import torchvision.transforms.functional as F -import itertools - -from .. import operators -from .. import math -from .. import fourier -from .InPlaneTransformBatch import InPlaneTransformBatch - -def _compute_shift_filters( shifts: torch.Tensor, - flattener: operators.SpectraFlattener, - dim: Sequence[int], - device: Optional[torch.device] = None ) -> torch.Tensor: - d = 1.0/(2*torch.pi) - frequency_grid = fourier.rfftnfreq(dim, d=d, device=device) - frequency_coefficients = flattener(frequency_grid) - return fourier.time_shift_filter(shifts.to(frequency_coefficients.device), frequency_coefficients) - -class FourierInPlaneTransformGenerator: - def __init__(self, - dim: Sequence[int], - angles: torch.Tensor, - shifts: torch.Tensor, - flattener: operators.SpectraFlattener, - ctfs: Optional[torch.Tensor] = None, - weights: Optional[torch.Tensor] = None, - norm: Optional[str] = None, - interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, - device: Optional[torch.device] = None ) -> None: - - # Transforms - self.angles = angles - self.shifts = shifts - self.shift_filters = _compute_shift_filters(-shifts, flattener, dim, device) # Invert shifts - - # Operations - self.fourier = operators.FourierTransformer2D() - self.flattener = flattener - self.ctfs = ctfs - self.weights = weights - self.interpolation = interpolation - self.norm = norm - - # Device - self.device = device - - def __call__(self, - images: Iterable[torch.Tensor]) -> InPlaneTransformBatch: - - if self.norm: - raise NotImplementedError('Normalization is not implemented') - - indices = None - rotated_images = None - rotated_fourier_transforms = None - rotated_bands = None - ctf_bands = None - shifted_bands = None - - start = 0 - for batch in images: - if self.device is not None: - batch = batch.to(self.device, non_blocking=True) - - end = start + len(batch) - indices = torch.arange(start=start, end=end, out=indices) - - for angle in self.angles: - rotated_images = F.rotate( - batch, - angle=float(angle), - interpolation=self.interpolation - ) - - rotated_fourier_transforms = self.fourier( - rotated_images, - out=rotated_fourier_transforms - ) - - rotated_bands = self.flattener( - rotated_fourier_transforms, - out=rotated_bands - ) - - if self.weights is not None: - rotated_bands *= self.weights - - ctfs = self.ctfs if self.ctfs is not None else itertools.repeat(None, times=1) - for ctf in ctfs: - if ctf is not None: - # Apply the CTF when provided - ctf_bands = torch.mul( - rotated_bands, - ctf, - out=ctf_bands - ) - else: - # No CTF - ctf_bands = rotated_bands - - for shift, shift_filter in zip(self.shifts, self.shift_filters): - shifted_bands = torch.mul( - ctf_bands, - shift_filter, - out=shifted_bands - ) - - yield InPlaneTransformBatch( - indices=indices, - vectors=math.flat_view_as_real(shifted_bands), - angle=float(angle), - shift=shift - ) - - # Advance the counter for the next iteration - start = end - \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/InPlaneTransformBatch.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/InPlaneTransformBatch.py deleted file mode 100644 index 42f1ddc06..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/InPlaneTransformBatch.py +++ /dev/null @@ -1,30 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import NamedTuple -import torch - -class InPlaneTransformBatch(NamedTuple): - indices: torch.IntTensor - vectors: torch.Tensor - shift: torch.Tensor - angle: float diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/__init__.py index 99354ae2a..5e1d064bb 100644 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/__init__.py +++ b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/__init__.py @@ -20,12 +20,4 @@ # * e-mail address 'xmipp@cnb.csic.es' # ***************************************************************************/ -from .align import align -from .train import train -from .populate import populate -from .generate_alignment_metadata import generate_alignment_metadata - -from .FourierInPlaneTransformAugmenter import FourierInPlaneTransformAugmenter -from .FourierInPlaneTransformGenerator import FourierInPlaneTransformGenerator -from .FourierInPlaneTransformCorrector import FourierInPlaneTransformCorrector from .InPlaneTransformCorrector import InPlaneTransformCorrector \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/align.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/align.py deleted file mode 100644 index 1d55c3e1e..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/align.py +++ /dev/null @@ -1,52 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Iterable -import torch - -from .. import search - - -def align(db: search.Database, - dataset: Iterable[torch.Tensor], - k: int ) -> search.SearchResult: - - database_device = db.get_input_device() - - index_vectors = [] - distance_vectors = [] - - for vectors in dataset: - # Search them - s = db.search(vectors.to(database_device), k=k) - - # Add them to the result - index_vectors.append(s.indices.cpu()) - distance_vectors.append(s.distances.cpu()) - - # Concatenate all result vectors - return search.SearchResult( - indices=torch.cat(index_vectors, axis=0), - distances=torch.cat(distance_vectors, axis=0) - ) - - \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/generate_alignment_metadata.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/generate_alignment_metadata.py deleted file mode 100644 index 26144be22..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/generate_alignment_metadata.py +++ /dev/null @@ -1,166 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional, Sequence - -import pandas as pd -import numpy as np - -from .. import metadata as md -from .. import search - -def _ensemble_alignment_md(reference_md: pd.DataFrame, - projection_md: pd.DataFrame, - match_distances: np.ndarray, - match_indices: np.ndarray, - reference_columns: Sequence[int], - index: Optional[np.ndarray] = None, - local_transform_md: Optional[pd.DataFrame] = None ) -> pd.DataFrame: - - result = pd.DataFrame(match_distances, columns=[md.COST], index=index) - - # Left-join the projection metadata to the result - result = result.join(projection_md, on=match_indices) - - # Left-join the reference metadata to the result - result = result.join(reference_md[reference_columns], on=md.REF) - - # Drop the indexing columns - result.drop(columns=md.REF, inplace=True) - - # Accumulate transforms - if local_transform_md is not None: - result[local_transform_md.columns] += local_transform_md - - # Make psi in [-180, 180] - if md.ANGLE_PSI in local_transform_md.columns: - result.loc[result[md.ANGLE_PSI] > +180.0, md.ANGLE_PSI] -= 360.0 - result.loc[result[md.ANGLE_PSI] < -180.0, md.ANGLE_PSI] += 360.0 - - return result - -def _update_alignment_metadata(output_md: pd.DataFrame, - reference_md: pd.DataFrame, - projection_md: pd.DataFrame, - match_distances: np.ndarray, - match_indices: np.ndarray, - reference_columns: Sequence[int], - local_transform_md: Optional[pd.DataFrame] = None ) -> pd.DataFrame: - # Select the rows to be updated - selection = match_distances < output_md[md.COST].to_numpy() - - if local_transform_md is not None: - local_transform_md = local_transform_md[selection] - - # Do an alignment for the selected rows - alignment_md = _ensemble_alignment_md( - reference_md=reference_md, - projection_md=projection_md, - match_distances=match_distances[selection], - match_indices=match_indices[selection], - reference_columns=reference_columns, - index=np.nonzero(selection)[0], - local_transform_md=local_transform_md - ) - - # Update output - output_md.loc[selection, alignment_md.columns] = alignment_md - - return output_md - -def _create_alignment_metadata(experimental_md: pd.DataFrame, - reference_md: pd.DataFrame, - projection_md: pd.DataFrame, - match_distances: np.ndarray, - match_indices: np.ndarray, - reference_columns: Sequence[int], - local_transform_md: Optional[pd.DataFrame] = None ) -> pd.DataFrame: - - # Use the first match - alignment_md = _ensemble_alignment_md( - reference_md=reference_md, - projection_md=projection_md, - match_distances=match_distances, - match_indices=match_indices, - reference_columns=reference_columns, - local_transform_md=local_transform_md - ) - - # Add the alignment consensus to the output - output_md = experimental_md.drop(columns=alignment_md.columns, errors='ignore') - output_md = output_md.join(alignment_md) - - # Reorder columns for more convenient reading - output_md = output_md.reindex( - columns=experimental_md.columns.union(output_md.columns, sort=False) - ) - - return output_md - -def generate_alignment_metadata(experimental_md: pd.DataFrame, - reference_md: pd.DataFrame, - projection_md: pd.DataFrame, - matches: search.SearchResult, - reference_columns: Sequence[int], - local_transform_md: Optional[pd.DataFrame] = None, - output_md: Optional[pd.DataFrame] = None) -> pd.DataFrame: - - # Rename the reference image column to make it compatible - # with the resulting MD. (No duplicated IMAGE column) - reference_md = reference_md.rename(columns={ - md.IMAGE: md.REFERENCE_IMAGE, - }) - - # Flatten the kNN results into a single dim and provide - # them as a numpy array. Concatenate columns to match - # the concatenation of experimental md - k = matches.indices.shape[1] - match_distances = matches.distances.numpy().flatten() - match_indices = matches.indices.numpy().flatten() - - # Update or generate depending on wether the output is provided - if output_md is None: - # Repeat each row of the experimental MD k times - experimental_md = experimental_md.loc[experimental_md.index.repeat(k)].reset_index(drop=True) - - output_md = _create_alignment_metadata( - experimental_md=experimental_md, - reference_md=reference_md, - projection_md=projection_md, - match_distances=match_distances, - match_indices=match_indices, - reference_columns=reference_columns, - local_transform_md=local_transform_md - ) - else: - output_md = _update_alignment_metadata( - output_md=output_md, - reference_md=reference_md, - projection_md=projection_md, - match_distances=match_distances, - match_indices=match_indices, - reference_columns=reference_columns, - local_transform_md=local_transform_md - ) - - assert(output_md is not None) - return output_md diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/populate.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/populate.py deleted file mode 100644 index 7b355d085..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/populate.py +++ /dev/null @@ -1,71 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Iterable -import pandas as pd - -from .. import search -from .. import metadata as md -from .InPlaneTransformBatch import InPlaneTransformBatch - -def populate(db: search.Database, - dataset: Iterable[InPlaneTransformBatch] ) -> pd.DataFrame: - - # Empty the database - db.reset() - - database_device = db.get_input_device() - - # Create arrays for appending MD - reference_indices = [] - angles = [] - shifts_x = [] - shifts_y = [] - - # Add elements - for batch in dataset: - # Populate the database - db.add(batch.vectors.to(database_device)) - - # Fill the metadata - reference_indices += batch.indices.tolist() - angles += [batch.angle] * len(batch.indices) - shifts_x += [float(batch.shift[0])] * len(batch.indices) - shifts_y += [float(batch.shift[1])] * len(batch.indices) - - # Create the output md - COLUMNS = [ - md.REF, - md.ANGLE_PSI, - md.SHIFT_X, - md.SHIFT_Y - ] - assert(len(reference_indices) == len(angles)) - assert(len(reference_indices) == len(shifts_x)) - assert(len(reference_indices) == len(shifts_y)) - - result = pd.DataFrame( - data=zip(reference_indices, angles, shifts_x, shifts_y), - columns=COLUMNS - ) - - return result \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/train.py b/src/xmipp/libraries/py_xmipp/swiftalign/alignment/train.py deleted file mode 100644 index d594c0dcf..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/alignment/train.py +++ /dev/null @@ -1,45 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Iterable -import torch - -from .. import search - - -def train(db: search.Database, - dataset: Iterable[torch.Tensor], - scratch: torch.Tensor ): - - # Write - start = 0 - for vectors in dataset: - end = start + len(vectors) - - # Write - scratch[start:end,:] = vectors.to(scratch.device, non_blocking=True) - - # Setup next iteration - start = end - - # Train the database - db.train(scratch[:start]) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/__init__.py deleted file mode 100644 index cec32742c..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from .compute_ctf_image_2d import compute_ctf_image_2d -from .wiener import wiener_2d \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py deleted file mode 100644 index 336abf87d..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/compute_ctf_image_2d.py +++ /dev/null @@ -1,116 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch -import math - -def _compute_defocus_grid_2d(frequency_angle_grid: torch.Tensor, - defocus_average: torch.Tensor, - defocus_difference: torch.Tensor, - astigmatism_angle: torch.Tensor, - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - out = torch.sub(frequency_angle_grid, astigmatism_angle[...,None,None], out=out) - out *= 2 - out.cos_() - out *= defocus_difference[...,None,None] - out += defocus_average[...,None,None] - - return out - -def _compute_beam_energy_spread(frequency_magnitude2_grid: torch.Tensor, - chromatic_aberration: torch.Tensor, - wavelength: float, - energy_spread_coefficient: float, - lens_inestability_coefficient: float, - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - # http://i2pc.es/coss/Articulos/Sorzano2007a.pdf - # Equation 10 - k = torch.pi / 4 * wavelength * (energy_spread_coefficient + 2*lens_inestability_coefficient) - x = chromatic_aberration * k - x.square_() - x *= -1.0 / math.log(2) - - out = torch.mul(x[...,None,None], frequency_magnitude2_grid.square(), out=out) - out.exp_() - - return out - -def compute_ctf_image_2d(frequency_magnitude2_grid: torch.Tensor, - frequency_angle_grid: torch.Tensor, - defocus_average: torch.Tensor, - defocus_difference: torch.Tensor, - astigmatism_angle: torch.Tensor, - wavelength: float, - spherical_aberration: float, - q0: Optional[float] = None, - chromatic_aberration: Optional[torch.Tensor] = None, - energy_spread_coefficient: Optional[float] = None, - lens_inestability_coefficient: Optional[float] = None, - phase_shift: Optional[float] = None, - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - k = 0.5 * spherical_aberration * wavelength * wavelength - - out = _compute_defocus_grid_2d( - frequency_angle_grid=frequency_angle_grid, - defocus_average=defocus_average, - defocus_difference=defocus_difference, - astigmatism_angle=astigmatism_angle, - out=out - ) - - # Compute the phase - out -= k*frequency_magnitude2_grid - out *= (torch.pi * wavelength) * frequency_magnitude2_grid - - # Apply the phase shift if provided - if phase_shift is not None: - out += phase_shift - - # Compute the sin, also considering the inelastic - # difraction factor if provided - if q0 is not None: - out = out.sin() + q0*out.cos() - else: - out.sin_() - - # Apply energy spread envelope - if (chromatic_aberration is not None) and \ - (energy_spread_coefficient is not None) and \ - (lens_inestability_coefficient is not None): - - beam_energy_spread = _compute_beam_energy_spread( - frequency_magnitude2_grid=frequency_magnitude2_grid, - chromatic_aberration=chromatic_aberration, - wavelength=wavelength, - energy_spread_coefficient=energy_spread_coefficient, - lens_inestability_coefficient=lens_inestability_coefficient - ) - out *= beam_energy_spread - - - - return out - \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/wiener.py b/src/xmipp/libraries/py_xmipp/swiftalign/ctf/wiener.py deleted file mode 100644 index 3f8426cc9..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/ctf/wiener.py +++ /dev/null @@ -1,48 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional, Sequence -import torch - -def wiener_2d( direct_filter: torch.Tensor, - inverse_ssnr: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None) -> torch.Tensor: - - # Compute the filter power (|H|²) at the output - if torch.is_complex(direct_filter): - out = torch.abs(direct_filter, out=out) - out.square_() - else: - out = torch.square(direct_filter, out=out) - - # Compute the default value for inverse SSNR if not - # provided - if inverse_ssnr is None: - inverse_ssnr = torch.mean(out, dim=(-2, -1)) - inverse_ssnr *= 0.1 - inverse_ssnr = inverse_ssnr[...,None,None] - - # H* / (|H|² + N/S) - out.add_(inverse_ssnr) - torch.div(torch.conj(direct_filter), out, out=out) - - return out diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/dct/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/dct/__init__.py deleted file mode 100644 index 126191a2d..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/dct/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from .basis import dct_ii_basis, dct_iii_basis -from .dct import bases_generator, dct, idct -from .project import project, project_nd \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/dct/basis.py b/src/xmipp/libraries/py_xmipp/swiftalign/dct/basis.py deleted file mode 100644 index 5f58a71a9..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/dct/basis.py +++ /dev/null @@ -1,60 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -import torch -import math - -def _get_nk(N: int): - n = torch.arange(N) - k = n.view(N, 1) - return n, k - -def dct_ii_basis(N: int, norm: bool = True) -> torch.Tensor: - n, k = _get_nk(N) - - result = (n + 0.5) * k - result *= torch.pi / N - result = torch.cos(result, out=result) - - # Normalize - if norm: - result *= math.sqrt(1/N) - result[1:,:] *= math.sqrt(2) - - return result - -def dct_iii_basis(N: int, norm: bool = True) -> torch.Tensor: - n, k = _get_nk(N) - - # TODO avoid computing result[:,0] twice - result = (k + 0.5) * n - result *= torch.pi / N - result = torch.cos(result, out=result) - - if norm: - result[:,0] = 1 / math.sqrt(2) - result *= math.sqrt(2/N) - else: - result[:,0] = 0.5 - - - return result \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/dct/dct.py b/src/xmipp/libraries/py_xmipp/swiftalign/dct/dct.py deleted file mode 100644 index 645be50d5..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/dct/dct.py +++ /dev/null @@ -1,48 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional, Sequence, Iterable, Callable -import torch - -from .basis import dct_ii_basis, dct_iii_basis -from .project import project_nd - -def bases_generator(shape: Sequence[int], - dims: Iterable[int], - func: Callable[[int], torch.Tensor]) -> Iterable[torch.Tensor]: - sizes = map(shape.__getitem__, dims) - bases = map(func, sizes) - return bases - -def dct(x: torch.Tensor, - dims: Iterable[int], - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - bases = bases_generator(x.shape, dims, dct_ii_basis) - return project_nd(x, dims, bases, out=out) - -def idct(x: torch.Tensor, - dims: Iterable[int], - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - bases = bases_generator(x.shape, dims, dct_iii_basis) - return project_nd(x, dims, bases, out=out) \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/dct/project.py b/src/xmipp/libraries/py_xmipp/swiftalign/dct/project.py deleted file mode 100644 index ac428e6d1..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/dct/project.py +++ /dev/null @@ -1,70 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional, Iterable -import torch - -def project(x: torch.Tensor, - dim: int, - basis: torch.Tensor, - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - if out is x: - raise Exception('Aliasing between x and out is not supported') - - def t(x: torch.Tensor) -> torch.Tensor: - return torch.transpose(x, dim, -1) - - # Transpose the input to have dim - # on the last axis - x = t(x) - if out is not None: - out = t(out) - - # Perform the projection - out = torch.matmul(basis, x, out=out) - - # Undo the transposition - out = t(out) # Return dim to its place - - return out - -def project_nd(x: torch.Tensor, - dims: Iterable[int], - bases: Iterable[torch.Tensor], - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - temp = None - for i, (dim, basis) in enumerate(zip(dims, bases)): - if i == 0: - # First iteration - out = project(x, dim, basis, out=out) - else: - assert(out is not None) - if temp is None: - temp = out.clone() - else: - temp[...] = out - - out = project(temp, dim, basis, out=out) - - return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py deleted file mode 100644 index 2ef3cb700..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from .remove_symmetic_half import remove_symmetric_half -from .rfftnfreq import rfftnfreq -from .time_shift_filter import time_shift_filter diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/remove_symmetic_half.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/remove_symmetic_half.py deleted file mode 100644 index fadce4819..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/remove_symmetic_half.py +++ /dev/null @@ -1,36 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -import torch - -def remove_symmetric_half(input: torch.Tensor) -> torch.Tensor: - """Removes the conjugate symmetry from a multidimensional Fourier transform - - Args: - input (torch.Tensor): Input tensor - - Returns: - torch.Tensor: Input without symmetry - """ - x_size = input.shape[-1] - half_x_size = x_size // 2 + 1 - return input[...,:half_x_size] diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/rfftnfreq.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/rfftnfreq.py deleted file mode 100644 index a7d92918f..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/rfftnfreq.py +++ /dev/null @@ -1,56 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Sequence, Optional -import torch - -def rfftnfreq(dim: Sequence[int], - d: float = 1.0, - dtype: Optional[type] = None, - device: Optional[torch.device] = None) -> torch.Tensor: - """Creates a multidimensional Fourier frequency grid - - Args: - dim (Sequence[int]): Image size - d (float, optional): Normalization. Defaults to 1.0. - dtype (Optional[type], optional): Element type. Defaults to float32. - device (Optional[torch.device], optional): Device. Defaults to CPU. - - Returns: - torch.Tensor: _description_ - """ - - def fftfreq(dim: int) -> torch.Tensor: - return torch.fft.fftfreq(dim, d=d, dtype=dtype, device=device) - - def rfftfreq(dim: int) -> torch.Tensor: - return torch.fft.rfftfreq(dim, d=d, dtype=dtype, device=device) - - # Compute the frequencies for each axis. - # For the last axis use rfft - axis_freq = list(map(fftfreq, dim[:-1])) - axis_freq.append(rfftfreq(dim[-1])) - - mesh = torch.meshgrid(*reversed(axis_freq), indexing='xy') - return torch.stack(mesh) - - \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/time_shift_filter.py b/src/xmipp/libraries/py_xmipp/swiftalign/fourier/time_shift_filter.py deleted file mode 100644 index 6064b930f..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/fourier/time_shift_filter.py +++ /dev/null @@ -1,45 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -def time_shift_filter(shift: torch.Tensor, - freq: torch.Tensor, - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - """Generates a multidimensional shift filter in Fourier space - - Args: - shift (torch.Tensor): Shift in samples. (B, n) where shape, where B is the batch size and n is the dimensions - freq (torch.Tensor): Frequency grid in radians. (B, dn, ... dy, dx) - out (Optional[torch.Tensor], optional): Preallocated tensor. Defaults to None. - - Returns: - torch.Tensor: _description_ - """ - - # Fourier time shift theorem: - # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Some_discrete_Fourier_transform_pairs - angles = -torch.matmul(shift, freq) - gain = torch.tensor(1.0).to(angles) # TODO try to avoid using this - out = torch.polar(gain, angles, out=out) - return out diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/__init__.py deleted file mode 100644 index ed70528cd..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/math/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from .complex_normalize import complex_normalize -from .l2_normalize import l2_normalize -from .mu_sigma_normalize import mu_sigma_normalize -from .flat_view_as_real import flat_view_as_real \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/complex_normalize.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/complex_normalize.py deleted file mode 100644 index 8567e32d9..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/math/complex_normalize.py +++ /dev/null @@ -1,35 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -def complex_normalize(data: torch.Tensor, - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - if out is data: - out /= torch.abs(out) - - else: - raise NotImplementedError('Only implemented for out=data') - - return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/flat_view_as_real.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/flat_view_as_real.py deleted file mode 100644 index f78e5bfb5..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/math/flat_view_as_real.py +++ /dev/null @@ -1,28 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -import torch - -def flat_view_as_real(input: torch.Tensor) -> torch.Tensor: - real = torch.view_as_real(input) - flat = torch.flatten(real, start_dim=-2, end_dim=-1) - return flat \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/l2_normalize.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/l2_normalize.py deleted file mode 100644 index ae94862c4..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/math/l2_normalize.py +++ /dev/null @@ -1,36 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Union, Sequence, Optional -import torch - -def l2_normalize(data: torch.Tensor, - dim: Union[None, int, Sequence[int]], - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - if out is data: - out /= torch.norm(out, dim=dim, keepdim=True) - - else: - raise NotImplementedError('Only implemented for out=data') - - return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/math/mu_sigma_normalize.py b/src/xmipp/libraries/py_xmipp/swiftalign/math/mu_sigma_normalize.py deleted file mode 100644 index 870777d81..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/math/mu_sigma_normalize.py +++ /dev/null @@ -1,38 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Union, Sequence, Optional -import torch - -def mu_sigma_normalize( data: torch.Tensor, - dim: Union[None, int, Sequence[int]], - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - - if out is data: - std, mean = torch.std_mean(out, dim=dim, keepdim=True) - out -= mean - out /= std - - else: - raise NotImplementedError('Only implemented for out=data') - - return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctLowPassFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctLowPassFlattener.py deleted file mode 100644 index a898c1c28..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctLowPassFlattener.py +++ /dev/null @@ -1,58 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional, Sequence -import torch - -from .SpectraFlattener import SpectraFlattener - -class DctLowPassFlattener(SpectraFlattener): - def __init__( self, - dim: int, - cutoff: float, - exclude_dc: bool = True, - padded_length: Optional[int] = None, - device: Optional[torch.device] = None ): - SpectraFlattener.__init__( - self, - self._compute_mask(dim, cutoff, exclude_dc), - padded_length=padded_length, - device=device - ) - - def _compute_mask( self, - dim: int, - cutoff: float, - exclude_dc: bool ) -> torch.Tensor: - - # Compute the frequency grid - freq_x = torch.linspace(start=0, end=0.5, steps=dim) - freq_y = freq_x[...,None] - freq2 = freq_x**2 + freq_y**2 - - # Compute the mask - cutoff2 = cutoff ** 2 - mask = freq2.less_equal(cutoff2) - if exclude_dc: - mask[0, 0] = False - - return mask diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctTransformer2D.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctTransformer2D.py deleted file mode 100644 index 3024c0a9f..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/DctTransformer2D.py +++ /dev/null @@ -1,47 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -from ..dct import dct_ii_basis, project_nd - -from .Transformer2D import Transformer2D - -class DctTransformer2D(Transformer2D): - DIMS = (-1, -2) # Last two dimensions - - def __init__(self, dim: int, device: Optional[torch.device] = None) -> None: - self._bases = (dct_ii_basis(dim).to(device), )*len(self.DIMS) - - def __call__( self, - input: torch.Tensor, - out: Optional[torch.Tensor] = None) -> torch.Tensor: - - # To avoid warnings - if out is not None: - out.resize_(0) - - return project_nd(input, dims=self.DIMS, bases=self._bases, out=out) - - def has_complex_output(self) -> bool: - return False \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierLowPassFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierLowPassFlattener.py deleted file mode 100644 index e978b82f0..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierLowPassFlattener.py +++ /dev/null @@ -1,62 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional, Sequence -import torch - -from .SpectraFlattener import SpectraFlattener -from ..fourier import rfftnfreq - -class FourierLowPassFlattener(SpectraFlattener): - def __init__( self, - dim: Sequence[int], - cutoff: float, - exclude_dc: bool = True, - padded_length: Optional[int] = None, - device: Optional[torch.device] = None ): - SpectraFlattener.__init__( - self, - self._compute_mask(dim, cutoff, exclude_dc), - padded_length=padded_length, - device=device - ) - - def _compute_mask( self, - dim: Sequence[int], - cutoff: float, - exclude_dc: bool ) -> torch.Tensor: - - # Compute the frequency grid - frequency_grid = rfftnfreq(dim) - frequencies2 = torch.sum(frequency_grid**2, dim=0) - - # Compute the mask - cutoff2 = cutoff ** 2 - mask = frequencies2.less_equal(cutoff2) - if exclude_dc: - # Remove symmetric coefficients and DC - mask[:(dim[-2]//2),0] = False - else: - # Remove symmetric coefficients - mask[1:(dim[-2]//2),0] = False - - return mask diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierShiftFilter.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierShiftFilter.py deleted file mode 100644 index f9a02ccfd..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierShiftFilter.py +++ /dev/null @@ -1,64 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -from .SpectraFlattener import SpectraFlattener - -class FourierShiftFilter: - def __init__(self, - dim: int, - shifts: torch.Tensor, - flattener: SpectraFlattener, - device: Optional[torch.device] = None): - - freq = self._get_freq(dim, flattener, device=device) - angles = -torch.mm(shifts.to(device), freq) - ONE = torch.tensor(1, dtype=angles.dtype, device=device) - self._filters = torch.polar(ONE, angles) - self._shifts = shifts - - def __call__( self, - input: torch.Tensor, - index: int, - out: Optional[torch.Tensor] = None ): - filter = self._filters[index,:] - out = torch.mul(input, filter, out=out) - return out - - def get_count(self) -> int: - return self._filters.shape[0] - - def get_shift(self, index: int) -> torch.Tensor: - return self._shifts[index] - - def _get_freq( self, - dim: int, - flattener: SpectraFlattener, - device: Optional[torch.device] = None) -> torch.Tensor: - - d = 0.5 / (dim*torch.pi) - freq_x = torch.fft.rfftfreq(dim, d=d, device=device) - freq_y = torch.fft.fftfreq(dim, d=d, device=device) - grid = torch.stack(torch.meshgrid(freq_x, freq_y, indexing='xy')) - return flattener(grid) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierTransformer2D.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierTransformer2D.py deleted file mode 100644 index b0dd0e774..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/FourierTransformer2D.py +++ /dev/null @@ -1,40 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -from .Transformer2D import Transformer2D - -class FourierTransformer2D(Transformer2D): - def __call__( self, - input: torch.Tensor, - out: Optional[torch.Tensor] = None) -> torch.Tensor: - - # To avoid warnings - if out is not None: - out.resize_(0) - - return torch.fft.rfft2(input, out=out) - - def has_complex_output(self) -> bool: - return True \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageRotator.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageRotator.py deleted file mode 100644 index c4aeafeac..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageRotator.py +++ /dev/null @@ -1,51 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch -import torchvision.transforms as T - -class ImageRotator: - def __init__( self, - angles: torch.Tensor, - device: Optional[torch.device] = None ): - self._angles = angles - - def __call__( self, - input: torch.Tensor, - index: int, - out: Optional[torch.Tensor] ) -> torch.Tensor: - - # TODO use the matrix - out = T.functional.rotate( - input, - self.get_angle(index), - T.InterpolationMode.BILINEAR, - ) - return out - - def get_count(self) -> int: - return len(self._angles) - - def get_angle(self, index: int) -> float: - return float(self._angles[index]) - \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageShifter.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageShifter.py deleted file mode 100644 index d0205eed8..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageShifter.py +++ /dev/null @@ -1,53 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -class ImageShifter: - def __init__(self, - shifts: torch.Tensor, - dim: int, - device: Optional[torch.device] = None): - - self._shifts = shifts - self._shifts_in_px = (shifts*dim).to(dtype=int) - - def __call__( self, - input: torch.Tensor, - index: int, - out: Optional[torch.Tensor] = None ): - - out = torch.roll( - input, - shifts=self._shifts_in_px[index].tolist(), - dims=(-1, -2) - ) - - return out - - def get_count(self) -> int: - return self._shifts.shape[0] - - def get_shift(self, index: int) -> torch.Tensor: - return self._shifts[index] - \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageSpectraFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageSpectraFlattener.py deleted file mode 100644 index cbd3ec79b..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/ImageSpectraFlattener.py +++ /dev/null @@ -1,74 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional - -import torch - -from .SpectraFlattener import SpectraFlattener -from .Transformer2D import Transformer2D -from .Weighter import Weighter - -from .. import utils - -class ImageSpectraFlattener: - def __init__(self, - transform: Transformer2D, - flattener: SpectraFlattener, - weighter: Optional[Weighter] = None, - norm: Optional[str] = None ) -> None: - - self._transform = transform - self._flattener = flattener - self._weighter = weighter - self._norm = norm - - self._transformed = None - - def __call__(self, - batch: torch.Tensor, - out: Optional[torch.Tensor] = None ) -> torch.Tensor: - # Normalize image if requested - if self._norm == 'image': - batch = batch.clone() - utils.normalize(batch, dim=(-2, -1)) - - # Compute the fourier transform of the images and flatten and weighten it - self._transformed = self._transform(batch, out=self._transformed) - self._transformed_flat = self._flattener(self._transformed, out=self._transformed_flat) - - # Apply the weights - if self._weighter is not None: - self._transformed_flat = self._weighter(self._transformed_flat, out=self._transformed_flat) - - # Normalize complex numbers if requested - if self._norm == 'complex': - utils.complex_normalize(self._transformed_flat) - - # Elaborate the reference vectors - out = utils.flat_view_as_real(self._transformed_flat) - - # Normalize reference vectors if requested - if self._norm == 'vector': - utils.l2_normalize(out, dim=-1) - - return out \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/MaskFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/MaskFlattener.py deleted file mode 100644 index 27f13ad38..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/MaskFlattener.py +++ /dev/null @@ -1,79 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -class MaskFlattener: - def __init__( self, - mask: torch.Tensor, - padded_length: Optional[int] = None, - device: Optional[torch.device] = None): - self._mask = mask - self._indices = self._calculate_indices(mask, device=device) - self._length = padded_length or len(self._indices) - - def __call__( self, - input: torch.Tensor, - out: Optional[torch.Tensor] = None) -> torch.Tensor: - - # Allocate the output - flatten_start_dim = -len(self.get_mask().shape) - batch_shape = input.shape[:flatten_start_dim] - output_shape = batch_shape + (self.get_length(), ) - out = torch.empty( - output_shape, - device=input.device, - dtype=input.dtype, - out=out - ) - - if input.shape[flatten_start_dim:] != self.get_mask().shape: - raise IndexError('Input has incorrect size') - - # Flatten in the same dims as the mask - flat_input = torch.flatten(input, start_dim=flatten_start_dim) - - # Write to the output - indices = self.get_indices() - k = len(indices) - out[...,:k] = flat_input[...,indices] - out[...,k:] = 0 - - return out - - def get_mask(self) -> torch.BoolTensor: - return self._mask - - def get_indices(self) -> torch.IntTensor: - return self._indices - - def get_length(self) -> int: - return self._length - - def _calculate_indices(self, - mask: torch.BoolTensor, - device: Optional[torch.device] = None ) -> torch.IntTensor: - flat_mask = torch.flatten(mask) - indices = torch.argwhere(flat_mask)[:,0] - return indices.to(device) - diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/SpectraFlattener.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/SpectraFlattener.py deleted file mode 100644 index c2d771884..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/SpectraFlattener.py +++ /dev/null @@ -1,33 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -from .MaskFlattener import MaskFlattener - -class SpectraFlattener(MaskFlattener): - def __init__( self, - mask: torch.Tensor, - padded_length: Optional[int] = None, - device: Optional[torch.device] = None): - super().__init__(mask, padded_length, device) diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/Transformer2D.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/Transformer2D.py deleted file mode 100644 index bc14a18b6..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/Transformer2D.py +++ /dev/null @@ -1,33 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import torch - -class Transformer2D: - def __call__(self, - input: torch.Tensor, - out: Optional[torch.Tensor] = None): - pass - - def has_complex_output(self) -> bool: - pass \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/operators/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/operators/__init__.py deleted file mode 100644 index a42e3ad79..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/operators/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from .Transformer2D import Transformer2D -from .FourierTransformer2D import FourierTransformer2D -from .DctTransformer2D import DctTransformer2D -from .MaskFlattener import MaskFlattener -from .SpectraFlattener import SpectraFlattener -from .FourierLowPassFlattener import FourierLowPassFlattener -from .DctLowPassFlattener import DctLowPassFlattener -from .ImageRotator import ImageRotator -from .ImageShifter import ImageShifter -from .FourierShiftFilter import FourierShiftFilter \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/search/Database.py b/src/xmipp/libraries/py_xmipp/swiftalign/search/Database.py deleted file mode 100644 index f87ed3e38..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/search/Database.py +++ /dev/null @@ -1,81 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import NamedTuple -import torch - -class SearchResult(NamedTuple): - indices: torch.IntTensor - distances: torch.Tensor - -class Database: - def __init__(self) -> None: - pass - - def train(self, vectors: torch.Tensor) -> None: - pass - - def add(self, vectors: torch.Tensor) -> None: - pass - - def finalize(self): - pass - - def reset(self): - pass - - def search(self, vectors: torch.Tensor, k: int) -> SearchResult: - pass - - def read(self, path: str): - pass - - def write(self, path: str): - pass - - def to_device(self, device: torch.device): - pass - - def is_trained(self) -> bool: - pass - - def is_populated(self) -> bool: - return self.get_item_count() > 0 - - def is_finalized(self) -> bool: - pass - - def get_dim(self) -> int: - pass - - def get_item_count(self) -> int: - pass - - def get_input_device(self) -> torch.device: - pass - - def _check_input(self, x: torch.Tensor): - if len(x.shape) != 2: - raise RuntimeError('Input should have 2 dimensions (batch and vector)') - - if x.shape[-1] != self.get_dim(): - raise RuntimeError('Input vectors have incorrect size') \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/search/Faiss.py b/src/xmipp/libraries/py_xmipp/swiftalign/search/Faiss.py deleted file mode 100644 index c3c12f90f..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/search/Faiss.py +++ /dev/null @@ -1,145 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional -import math -import torch -import faiss -import faiss.contrib.torch_utils - -from .Database import Database, SearchResult - -def opq_ifv_pq_recipe(dim: int, size: int = int(3e6), c: float = 16, norm=False): - """ - Values selected using: - https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index - """ - # Determine the parameters - PQ_BLOCK_SIZE = 4 # Coefficients per PQ partition - PQ_MAX_BYTES_PER_VECTOR = 48 # For single precision - PQ_MAX_VECTOR_SIZE = PQ_BLOCK_SIZE * PQ_MAX_BYTES_PER_VECTOR - d = min(PQ_MAX_VECTOR_SIZE, dim//PQ_BLOCK_SIZE*PQ_BLOCK_SIZE) - m = d // PQ_BLOCK_SIZE # Bytes per vector. 48 is the max for GPU and single precision - k = c * math.sqrt(size) # Number of clusters - k = 2 ** math.ceil(math.log2(k)) # Use a power of 2 - - # Elaborate the recipe for the factory - opq = f'OPQ{m}_{d}' # Only supported on CPU - ifv = f'IVF{k}' # In order to support GPU do not use HNSW32 - pq = f'PQ{m}' - recipe = (opq, ifv, pq) - - if norm: - l2norm = 'L2norm' - recipe = (l2norm, ) + recipe - - return ','.join(recipe) - -class FaissDatabase(Database): - def __init__(self, - dim: int = 0, - recipe: Optional[str] = None ) -> None: - - index: Optional[faiss.Index] = None - if recipe and dim: - index = faiss.index_factory(dim, recipe) - - self._index = index - - def train(self, vectors: torch.Tensor) -> None: - self._check_input(vectors) - self._sync(vectors) - self._index.train(vectors) - - def add(self, vectors: torch.Tensor) -> None: - self._check_input(vectors) - self._sync(vectors) - self._index.add(vectors) - - def reset(self): - self._index.reset() - - def search(self, vectors: torch.Tensor, k: int) -> SearchResult: - self._check_input(vectors) - self._sync(vectors) - distances, indices = self._index.search(vectors, k) - return SearchResult(indices=indices, distances=distances) - - def read(self, path: str): - self._index = faiss.read_index(path) - - def write(self, path: str): - faiss.write_index(self._index, path) - - def to_device(self, - device: torch.device, - use_f16: bool = False, - reserve_vecs: int = 0, - use_precomputed = False ): - if device.type == 'cuda': - resources = faiss.StandardGpuResources() - resources.setDefaultNullStreamAllDevices() # To interop with torch - co = faiss.GpuClonerOptions() - co.useFloat16 = use_f16 - co.useFloat16CoarseQuantizer = use_f16 - co.usePrecomputed = use_precomputed - co.reserveVecs = reserve_vecs - - self._index = faiss.index_cpu_to_gpu( - resources, - device.index, - self._index, - co - ) - - elif device.type == 'cpu': - self._index = faiss.index_gpu_to_cpu(self._index) - - else: - raise ValueError('Input device must be CPU or CUDA') - - def is_trained(self) -> bool: - return self._index.is_trained - - def is_finalized(self) -> bool: - return True - - def get_dim(self) -> int: - return self._index.d - - def get_item_count(self) -> int: - return self._index.ntotal - - def get_input_device(self) -> torch.device: - return torch.device('cpu') # TODO determine - - def set_metric_type(self, metric_type: int): - self._index.metric_type = metric_type - - def get_metric_type(self) -> int: - return self._index.metric_type - - def _sync(self, vectors: torch.Tensor): - if vectors.device.type == 'cuda': - stream = torch.cuda.current_stream(vectors.device) - event = stream.record_event() - raise NotImplementedError('We should sync here') \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/search/MedianHash.py b/src/xmipp/libraries/py_xmipp/swiftalign/search/MedianHash.py deleted file mode 100644 index 1acbc8a97..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/search/MedianHash.py +++ /dev/null @@ -1,139 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from typing import Optional, List -import torch - -from .Database import Database, SearchResult - -class MedianHashDatabase(Database): - def __init__(self, - dim: int = 0) -> None: - - self._dim = dim - self._median: Optional[torch.Tensor] = None - self._hashes: List[torch.Tensor] = [] - - def train(self, vectors: torch.Tensor) -> None: - self._check_input(vectors) - self._median = torch.median(vectors, dim=0).values - - def add(self, vectors: torch.Tensor) -> None: - self._check_input(vectors) - self._hashes.append(self.__compute_hash(vectors)) - - def reset(self): - self._hashes.clear() - - def search(self, vectors: torch.Tensor, k: int) -> SearchResult: - if k != 1: - raise NotImplementedError('KNN has been only implemented for k=1') - - self._check_input(vectors) - - # Compute the signature of the search vectors - search_hashes = self.__compute_hash(vectors) - - result = None - - xors = None - pop_counts = None - best_candidate = None - mask = None - base_index = 0 - for reference_hashes in self._hashes: - # Perform a XOR for all possible pairs - xors = torch.logical_xor( - search_hashes[:,None,:], - reference_hashes[None,:,:], - out=xors - ) - - # Do pop-count - pop_counts = torch.count_nonzero(xors, dim=-1) # TODO add out=count - - # Find the best candidates - best_candidate = torch.min(pop_counts, dim=-1, out=best_candidate) - - # Evaluate new candidates - if result is None: - result = SearchResult( - indices=best_candidate.indices.clone(), - distances=best_candidate.values.clone() - ) - else: - mask = torch.less(best_candidate.values, result.distances, out=mask) - result.indices[mask] = best_candidate.indices[mask] + base_index - result.distances[mask] = best_candidate.values[mask] - - # Update base index for next batch - base_index += len(reference_hashes) - - # Add a dimension in the end - return SearchResult( - indices=result.indices[...,None], - distances=result.distances[...,None] - ) - - def read(self, path: str): - obj = torch.load(path) - self._dim = obj['dim'] - self._median = obj['median'] - self._hashes = obj['hashes'] - - def write(self, path: str): - obj = { - 'dim': self._dim, - 'median': self._median, - 'hashes': self._hashes - } - torch.save(obj, path) - - def to_device(self, device: torch.device): - def func(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - if x is None: - return None - else: - return x.to(device=device) - - self._median = func(self._median) - self._hashes = list(map(func, self._hashes)) - - def is_trained(self) -> bool: - return self._median is not None - - def is_finalized(self) -> bool: - return True - - def get_dim(self) -> int: - return self._dim - - def get_item_count(self) -> int: - return sum(map(len, self._hashes)) - - def get_input_device(self) -> torch.device: - return self._median.device - - def __compute_hash(self, - vectors: torch.Tensor, - out: Optional[torch.BoolTensor] = None ) -> torch.BoolTensor: - return torch.greater(self._median, vectors, out=out) \ No newline at end of file diff --git a/src/xmipp/libraries/py_xmipp/swiftalign/search/__init__.py b/src/xmipp/libraries/py_xmipp/swiftalign/search/__init__.py deleted file mode 100644 index c5ba86ba2..000000000 --- a/src/xmipp/libraries/py_xmipp/swiftalign/search/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# *************************************************************************** -# * Authors: Oier Lauzirika Zarrabeitia (oierlauzi@bizkaia.eu) -# * -# * This program is free software; you can redistribute it and/or modify -# * it under the terms of the GNU General Public License as published by -# * the Free Software Foundation; either version 2 of the License, or -# * (at your option) any later version. -# * -# * This program is distributed in the hope that it will be useful, -# * but WITHOUT ANY WARRANTY; without even the implied warranty of -# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# * GNU General Public License for more details. -# * -# * You should have received a copy of the GNU General Public License -# * along with this program; if not, write to the Free Software -# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA -# * 02111-1307 USA -# * -# * All comments concerning this program package may be sent to the -# * e-mail address 'xmipp@cnb.csic.es' -# ***************************************************************************/ - -from .Database import Database, SearchResult -from .Faiss import FaissDatabase, opq_ifv_pq_recipe -from .MedianHash import MedianHashDatabase \ No newline at end of file