diff --git a/xmipptomo/protocols/__init__.py b/xmipptomo/protocols/__init__.py index e35fcef..da5bcb8 100644 --- a/xmipptomo/protocols/__init__.py +++ b/xmipptomo/protocols/__init__.py @@ -40,6 +40,7 @@ from .protocol_peak_high_contrast import XmippProtPeakHighContrast from .protocol_phantom_subtomo import XmippProtPhantomSubtomo from .protocol_phantom_tomo import XmippProtPhantomTomo +from .protocol_project_subtomograms import XmippProtProjectSubtomograms from .protocol_project_top import XmippProtSubtomoProject from .protocol_reconstruct_tomograms import XmippProtReconstructTomograms from .protocol_resizeTS import XmippProtResizeTiltSeries @@ -51,4 +52,4 @@ from .protocol_subtraction_subtomo import XmippProtSubtractionSubtomo from .protocol_subtomo_map_back import XmippProtSubtomoMapBack from .protocol_half_maps_subtomos import XmippProtHalfMapsSubtomo -from .protocol_project_subtomograms import XmippProtProjectSubtomograms + diff --git a/xmipptomo/protocols/protocol_project_subtomograms.py b/xmipptomo/protocols/protocol_project_subtomograms.py index 7751d05..8f63e81 100644 --- a/xmipptomo/protocols/protocol_project_subtomograms.py +++ b/xmipptomo/protocols/protocol_project_subtomograms.py @@ -30,12 +30,14 @@ import os, math from typing import Tuple, Union, TypedDict from emtable import Table +import numpy as np # Scipion em imports from pwem.protocols import EMProtocol -from pwem.objects import SetOfParticles, CTFModel, Float, Volume -from pwem.emlib import MetaData, metadata, MDL_CTF_PHASE_SHIFT, MDL_CTF_DEFOCUSU -from pwem.emlib import MDL_CTF_DEFOCUSV, MDL_CTF_DEFOCUS_ANGLE, MDL_ANGLE_TILT +from pwem.objects import SetOfParticles, CTFModel, Float, Volume, Particle, Transform, String +from pwem.emlib.image import ImageHandler as ih +from pwem.emlib import MetaData, metadata, createEmptyFile, MDL_CTF_PHASE_SHIFT, MDL_CTF_DEFOCUSV, \ + MDL_CTF_DEFOCUS_ANGLE, MDL_ANGLE_TILT from pyworkflow import BETA from pyworkflow.protocol import params from pyworkflow.utils import Message @@ -59,6 +61,14 @@ class XmippProtProjectSubtomograms(EMProtocol, ProtTomoBase): _devStatus = BETA _possibleOutputs = {OUTPUTATTRIBUTE: SetOfParticles} + PROJECTION_AXIS = ['X', 'Y', 'Z'] + AXIS_X = 0 + AXIS_Y = 1 + AXIS_Z = 2 + PROJECTION_DIRECTION = ['Along x, y, or z axis', 'Other directions'] + DIRECTION_AXIS = 0 + DIRECTION_OTHER = 1 + # Form constants METHOD_FOURIER = 0 METHOD_REAL_SPACE = 1 @@ -92,92 +102,180 @@ def _defineParams(self, form): # Generating form form.addSection(label=Message.LABEL_INPUT) - form.addParam('inputSubtomograms', params.PointerParam, pointerClass="SetOfSubTomograms,SetOfVolumes", important=True, - label='Set of subtomograms', help="Set of subtomograms whose projections will be generated.") - form.addParam('hasCtfCorrected', params.BooleanParam, default=True, label='Is CTF corrected?: ', - help='Set this option to True if the input set of subtomograms has no CTF or the CTF has been corrected.') - form.addParam('inputCTF', params.PointerParam, pointerClass="SetOfCTFTomoSeries", condition='not hasCtfCorrected', - label='CTF', help="Select the CTF corresponding to this set of subtomograms.") - form.addParam('defocusDir', params.BooleanParam, default=True, label='Defocus increase with z positive?: ', condition='not hasCtfCorrected', - help='This flag must be put if the defocus increases or decreases along the z-axis. This is required to set the local CTF.') + form.addParam('inputSubtomograms', params.PointerParam, pointerClass="SetOfSubTomograms", important=True, + label='Set of subtomograms', + help="Set of subtomograms whose projections will be generated.") + + form.addParam('hasCtfCorrected', params.BooleanParam, default=True, + label='Is CTF corrected?: ', + help='Set this option to True if the input set of subtomograms has no CTF or the CTF has been corrected.') + + form.addParam('inputCTF', params.PointerParam, pointerClass="SetOfCTFTomoSeries", + condition='not hasCtfCorrected', + label='CTF', + help="Select the CTF corresponding to this set of subtomograms.") + + form.addParam('defocusDir', params.BooleanParam, default=True, + label='Defocus increase with z positive?: ', + condition='not hasCtfCorrected', + help='This flag must be put if the defocus increases or decreases along the z-axis. This is required to set the local CTF.') + + form.addParam('projectionDirection', params.EnumParam, + choices=self.PROJECTION_DIRECTION, + default=self.DIRECTION_AXIS, + display=params.EnumParam.DISPLAY_HLIST, + label='Project x, y, z-axis or other directions') + + form.addParam('projectionAxis', params.EnumParam, choices=self.PROJECTION_AXIS, + condition='projectionDirection==%s' %self.DIRECTION_AXIS, + default=self.AXIS_Z, display=params.EnumParam.DISPLAY_HLIST, + label='Projection direction') + form.addParam('cleanTmps', params.BooleanParam, default=True, label='Clean temporary files: ', expertLevel=params.LEVEL_ADVANCED, help='Clean temporary files after finishing the execution.\nThis is useful to reduce unnecessary disk usage.') - form.addParam('transformMethod', params.EnumParam, display=params.EnumParam.DISPLAY_COMBO, default=self.METHOD_FOURIER, - choices=['Fourier', 'Real space', 'Shears'], label="Transform method: ", expertLevel=params.LEVEL_ADVANCED, - help='Select the algorithm that will be used to obtain the projections.') - + form.addParam('transformMethod', params.EnumParam, display=params.EnumParam.DISPLAY_COMBO, + condition='projectionDirection==%s' % self.DIRECTION_OTHER, + default=self.METHOD_FOURIER, + choices=['Fourier', 'Real space', 'Shears'], + label="Transform method: ", expertLevel=params.LEVEL_ADVANCED, + help='Select the algorithm that will be used to obtain the projections.') + # Parameter group for fourier transform method - fourierGroup = form.addGroup('Fourier parameters', condition=f"transformMethod=={self.METHOD_FOURIER}", expertLevel=params.LEVEL_ADVANCED) + fourierGroup = form.addGroup('Fourier parameters', + condition=f"transformMethod=={self.METHOD_FOURIER} and projectionDirection=={self.DIRECTION_OTHER}", + expertLevel=params.LEVEL_ADVANCED) fourierGroup.addParam('pad', params.IntParam, default=2, label="Pad: ", help="Controls the padding factor.") - fourierGroup.addParam('maxfreq', params.FloatParam, default=0.25, label="Maximum frequency: ", - help="Maximum frequency for the pixels.\nBy default, pixels with frequency more than 0.25 are not considered.") - fourierGroup.addParam('interp', params.EnumParam, default=self.INTERPOLATION_BSPLINE, label="Interpolation method: ", choices=["BSpline", "Nearest", "Linear"], - help="Method for interpolation.\nOptions:\n\nBSpline: Cubic BSpline\nNearest: Nearest Neighborhood\nLinear: Linear BSpline") - + fourierGroup.addParam('maxfreq', params.FloatParam, default=0.25, + label="Maximum frequency: ", + help="Maximum frequency for the pixels.\nBy default, pixels with frequency more than 0.25 are not considered.") + fourierGroup.addParam('interp', params.EnumParam, default=self.INTERPOLATION_BSPLINE, + label="Interpolation method: ", choices=["BSpline", "Nearest", "Linear"], + help="Method for interpolation.\nOptions:\n\nBSpline: Cubic BSpline\nNearest: Nearest Neighborhood\nLinear: Linear BSpline") + # Tilt related parameter group - tiltGroup = form.addGroup('Tilt parameters') + tiltGroup = form.addGroup('Tilt parameters', condition='projectionDirection==%s' % self.DIRECTION_OTHER) tiltGroup.addParam('tiltTypeGeneration', params.EnumParam, display=params.EnumParam.DISPLAY_COMBO, default=self.TYPE_N_SAMPLES, - choices=['NSamples', 'Step', 'Tilt Series'], label="Type of sample generation: ", - help='Select the method for generating samples:\n\n' - '*NSamples*: N samples are generated homogeneously across the whole tilt range.\n' - '*Step*: For the whole tilt range, a sample is generated every N gedrees.\n' - '*Tilt Series*: Given a set of Tilt Series, the angles at which each Tilt Series was taken are used.') - tiltGroup.addParam('tiltRangeNSamples', params.IntParam, condition=f'tiltTypeGeneration=={self.TYPE_N_SAMPLES}', label='Number of samples:', - help='Number of samples to be produced.\nIt has to be 1 or greater.') - tiltGroup.addParam('tiltRangeStep', params.IntParam, condition=f'tiltTypeGeneration=={self.TYPE_STEP}', label='Step:', - help='Number of degrees each sample will be separated from the next.\nIt has to be greater than 0.') - tiltGroup.addParam('tiltRangeTS', params.PointerParam, pointerClass="SetOfTiltSeries", condition=f'tiltTypeGeneration=={self.TYPE_TILT_SERIES}', - label='Set of Tilt Series:', help='Set of Tilt Series where the angles of each Tilt Series will be obtained for the projection.') + choices=['Number of Samples in a range', 'Tilt range and Step', 'Tilt Series angles'], + label="Projection angles defined by: ", + help='Select the method for generating samples:\n\n' + '*Number of Samples*: N samples are generated homogeneously across the whole tilt range.\n' + '*Tilt range and step*: For the whole tilt range, a sample is generated every N gedrees.\n' + '*Tilt Series angles*: Given a set of Tilt Series, the angles at which each Tilt Series was taken are used.') + tiltGroup.addParam('tiltRangeNSamples', params.IntParam, + condition=f'tiltTypeGeneration=={self.TYPE_N_SAMPLES}', + label='Number of samples:', + help='Number of samples to be produced.\nIt has to be 1 or greater.') + + tiltGroup.addParam('tiltRangeTS', params.PointerParam, pointerClass="SetOfTiltSeries", + condition=f'tiltTypeGeneration=={self.TYPE_TILT_SERIES}', + label='Tilt Series:', + help='Set of Tilt Series where the angles of each Tilt Series will be obtained for the projection.') + tiltLine = tiltGroup.addLine("Tilt range (degrees)", condition=f'tiltTypeGeneration!={self.TYPE_TILT_SERIES}', - help='The initial and final values of the range of angles the projection will be produced on.\nDefaults to -60º for initial and 60º for final.') + help='The initial and final values of the range of angles the projection will be produced on.\nDefaults to -60º for initial and 60º for final.') tiltLine.addParam('tiltRangeStart', params.IntParam, default=-60, label='Start: ') tiltLine.addParam('tiltRangeEnd', params.IntParam, default=60, label='End: ') + tiltLine.addParam('tiltRangeStep', params.FloatParam, + condition=f'tiltTypeGeneration=={self.TYPE_STEP}', + label='Step:', + help='Number of degrees each sample will be separated from the next.\nIt has to be greater than 0.') # --------------------------- INSERT steps functions -------------------------------------------- def _insertAllSteps(self): - # Defining list of function ids to be waited by the createOutput function - paramDeps = [] - - if self.tiltTypeGeneration.get() == self.TYPE_TILT_SERIES: - # Obtaining list of angle tilts and rots per Tilt Series - angles = self.getAngleDictionary() - - # Generating a metadata angle file and a param file for each Tilt Series - for tsId in angles: - # Getting metadata angle file path - angleFile = self.getAngleFileAbsolutePath(tsId) - - # Generating metadata angle file with Tilt Series's data - angTable = Table(columns=[self.COLUMN_ANGLE_PSI, self.COLUMN_ANGLE_ROT, self.COLUMN_ANGLE_TILT]) - for values in angles[tsId]: - angTable.addRow( - 0.0, # anglePsi - values[self.COLUMN_ANGLE_ROT], # angleRot - values[self.COLUMN_ANGLE_TILT] # angleTilt - ) - angTable.write(angleFile, tableName='projectionAngles') - - # Generating param file - paramDeps.append(self._insertFunctionStep(self.generateParamFile, tsId)) + if self.projectionDirection.get() == self.DIRECTION_AXIS: + self._insertFunctionStep(self.topSideViewProjectionStep) + self._insertFunctionStep(self.createOutputTopStep) else: - # If type of generation is not from Tilt Series, generate single param file - paramDeps.append(self._insertFunctionStep(self.generateParamFile)) - - # Generating projections for each subtomogram - generationDeps = [] - for subtomogram in self.inputSubtomograms.get(): - tsId = '' if self.tiltTypeGeneration.get() != self.TYPE_TILT_SERIES or len([ts for ts in self.tiltRangeTS.get()]) == 1 else subtomogram.getCoordinate3D().getTomoId() - generationDeps.append(self._insertFunctionStep(self.generateSubtomogramProjections, subtomogram.getFileName(), tsId, prerequisites=paramDeps)) - - # Conditionally removing temporary files - if self.cleanTmps.get(): - self._insertFunctionStep(self.removeTempFiles, prerequisites=generationDeps) - - # Create output - self._insertFunctionStep(self.createOutputStep, prerequisites=generationDeps) + # Defining list of function ids to be waited by the createOutput function + paramDeps = [] + + if self.tiltTypeGeneration.get() == self.TYPE_TILT_SERIES: + angles = self.getAngleDictionary() + print(angles) + # Generating a metadata angle file and a param file for each Tilt Series + for tsId in angles: + # Getting metadata angle file path + angleFile = self.getAngleFileAbsolutePath(tsId) + + # Generating metadata angle file with Tilt Series's data + angTable = Table(columns=[self.COLUMN_ANGLE_PSI, self.COLUMN_ANGLE_ROT, self.COLUMN_ANGLE_TILT]) + for values in angles[tsId]: + angTable.addRow( + 0.0, # anglePsi + values[self.COLUMN_ANGLE_ROT], # angleRot + values[self.COLUMN_ANGLE_TILT] # angleTilt + ) + angTable.write(angleFile, tableName='projectionAngles') + + # Generating param file + paramDeps.append(self._insertFunctionStep(self.generateParamFile, tsId)) + + else: + # If type of generation is not from Tilt Series, generate single param file + paramDeps.append(self._insertFunctionStep(self.generateParamFile)) + + # Generating projections for each subtomogram + generationDeps = [] + for subtomogram in self.inputSubtomograms.get(): + if self.tiltTypeGeneration.get() != self.TYPE_TILT_SERIES or len([ts for ts in self.tiltRangeTS.get()]) == 1: + tsId = '' + else: + tsId = subtomogram.getCoordinate3D().getTomoId() + generationDeps.append(self._insertFunctionStep(self.generateSubtomogramProjections, subtomogram.getFileName(), tsId, prerequisites=paramDeps)) + + # Conditionally removing temporary files + if self.cleanTmps.get(): + self._insertFunctionStep(self.removeTempFiles, prerequisites=generationDeps) + + # Create output + self._insertFunctionStep(self.createOutputStep, prerequisites=generationDeps) # --------------------------- STEPS functions -------------------------------------------- + def topSideViewProjectionStep(self): + input = self.inputSubtomograms.get() + x, y, z = input.getDim() + + dir = self.projectionAxis.get() + + fnProj = self._getExtraPath("projections.mrcs") + createEmptyFile(fnProj, x, y, 1, input.getSize()) + + for subtomo in input.iterItems(): + + fn = "%s@%s" % subtomo.getLocation() + if fn.endswith('.mrc'): + fn += ':mrc' + + vol = ih().read(fn) + img = ih().createImage() + + volData = vol.getData() + + proj = np.empty([x, y]) + + # X axis + if dir == self.AXIS_X: + for zi in range(z): + for yi in range(y): + proj[yi, zi] = np.sum(volData[zi, yi, :]) + # Y axis + elif dir == self.AXIS_Y: + for zi in range(z): + for xi in range(x): + proj[zi, xi] = np.sum(volData[zi, :, xi]) + # Z axis + elif dir == self.AXIS_Z: + for xi in range(x): + for yi in range(y): + proj[yi, xi] = np.sum(volData[:, yi, xi]) + + # Make the projection to be the image data + img.setData(proj) + + # Write the image at a specific slice + img.write('%d@%s' % (subtomo.getObjId(), fnProj)) + def generateParamFile(self, tsId: str=''): """ This function writes the config file for Xmipp Phantom. @@ -299,6 +397,65 @@ def createOutputStep(self): self._defineOutputs(outputSetOfParticles=outputSetOfParticles) self._defineSourceRelation(self.inputSubtomograms, outputSetOfParticles) + def createOutputTopStep(self): + input = self.inputSubtomograms.get() + imgSetOut = self._createSetOfParticles() + imgSetOut.setSamplingRate(input.getSamplingRate()) + imgSetOut.setAlignmentProj() + + listOfDefocus = {} + + if not self.hasCtfCorrected.get(): + for ctf in self.inputCTF.get(): + tsId = ctf.getTsId() + maxtilt = 1000 #This is a non-sense value great enough + defocusU = 0.0 + defocusV = 0.0 + defocusAng = 0.0 + for ti in ctf.getTiltSeries().iterItems(): + absoluteTiltAngle = abs(ti.getTiltAngle()) + if absoluteTiltAngle<=maxtilt: + ctfModel = ctf.getCtfTomoFromTi(ti) + maxtilt = absoluteTiltAngle + defocusU = ctfModel.getDefocusU() + defocusV = ctfModel.getDefocusV() + defocusAng = ctfModel.getDefocusAngle() + phaseShift = ctfModel.getPhaseShift() + listOfDefocus[tsId] = [defocusU, defocusV, defocusAng, phaseShift] + + # Input could be SetOfVolumes or SetOfSubtomograms + for item in input.iterItems(): + idx = item.getObjId() + + p = Particle(objId=item.getObjId()) + p.setLocation(ih._convertToLocation((idx, self._getExtraPath("projections.mrcs")))) + p._subtomogramID = String(idx) + + if item.hasTransform(): + transform = Transform() + transform.setMatrix(item.getTransform().getMatrix()) + p.setTransform(transform) + + if self.hasCtfCorrected: + ctfModel = CTFModel() + ctfModel.setDefocusU(0.0) + ctfModel.setDefocusV(0.0) + ctfModel.setDefocusAngle(0.0) + ctfModel.setPhaseShift(0.0) + else: + subtomoTsId = item.getCoordinate3D().getTomoId() + ctfTs = listOfDefocus[subtomoTsId] + ctfModel = CTFModel() + ctfModel.setDefocusU(ctfTs[0]) + ctfModel.setDefocusV(ctfTs[1]) + ctfModel.setDefocusAngle(ctfTs[2]) + ctfModel.setPhaseShift(ctfTs[3]) + imgSetOut.append(p) + + self._defineOutputs(outputSetOfParticles=imgSetOut) + self._defineSourceRelation(self.inputSubtomograms, imgSetOut) + + # --------------------------- INFO functions -------------------------------------------- def _validate(self): """ @@ -545,14 +702,14 @@ def getClosestCTF(self, tiltSeries: TiltSeries, tiltAngle: Float) -> CTFModel: # Returning closest CTF return outputCTF - + def getAngleDictionary(self) -> TypedDict: """ This function returs a dictionary containing all the angles of each Tilt Series of the input set """ # Initializing empty dictionary angleDict = {} - + if self.tiltTypeGeneration.get() == self.TYPE_TILT_SERIES: for ts in self.tiltRangeTS.get(): tsId = ts.getTsId() diff --git a/xmipptomo/protocols/protocol_project_top.py b/xmipptomo/protocols/protocol_project_top.py index b80cd16..691275f 100644 --- a/xmipptomo/protocols/protocol_project_top.py +++ b/xmipptomo/protocols/protocol_project_top.py @@ -154,7 +154,7 @@ def createOutputStep(self): # Input could be SetOfVolumes or SetOfSubtomograms for item in input.iterItems(): idx = item.getObjId() - p = Particle() + p = Particle(objId=item.getObjId()) p.setLocation(ih._convertToLocation((idx, self._getExtraPath("projections.mrcs")))) p._subtomogramID = String(idx)