Skip to content

Commit

Permalink
Merge branch 'devel' into fixmonotomo
Browse files Browse the repository at this point in the history
  • Loading branch information
Vilax committed Oct 27, 2023
2 parents dbcf749 + a37854e commit dab0362
Show file tree
Hide file tree
Showing 12 changed files with 960 additions and 626 deletions.
165 changes: 108 additions & 57 deletions xmipptomo/protocols/protocol_deep_misalignment_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,18 @@
from pwem.emlib import MetaData, MDL_MAX, MDL_MIN
from pwem.protocols import EMProtocol
from pyworkflow import BETA
from pyworkflow.object import Set
from pyworkflow.protocol import PointerParam, EnumParam, FloatParam, BooleanParam, LEVEL_ADVANCED
from pyworkflow.object import Set, Float
from pyworkflow.protocol import PointerParam, EnumParam, FloatParam, BooleanParam, LEVEL_ADVANCED, StringParam, \
GPU_LIST, USE_GPU
from tomo.objects import SetOfCoordinates3D, SetOfTomograms, Coordinate3D, SubTomogram, SetOfSubTomograms
from tomo.protocols import ProtTomoBase
from tomo import constants
from xmipp3 import XmippProtocol
from xmipptomo import utils

COORDINATES_FILE_NAME = 'subtomo_coords.xmd'
BOX_SIZE = 32
TARGET_SAMPLING_RATE = 6.25
COORDINATES_EXTRACTED_FILE_NAME = 'subtomo_coords_extracted.xmd'
TARGET_BOX_SIZE = 32


class XmippProtDeepDetectMisalignment(EMProtocol, ProtTomoBase, XmippProtocol):
Expand All @@ -65,6 +66,22 @@ def __init__(self, **kwargs):

# --------------------------- DEFINE param functions ------------------------
def _defineParams(self, form):
form.addHidden(USE_GPU,
BooleanParam,
default=True,
label="Use GPU for execution",
help="This protocol has both CPU and GPU implementation. "
"Select the one you want to use.")

form.addHidden(GPU_LIST,
StringParam,
default='0',
expertLevel=LEVEL_ADVANCED,
label="Choose GPU IDs",
help="GPU ID. To pick the best available one set 0. "
"For a specific GPU set its number ID "
"(starting from 1).")

form.addSection(label='Input')

form.addParam('inputSetOfCoordinates',
Expand Down Expand Up @@ -95,6 +112,12 @@ def _defineParams(self, form):
help='Tomograms from which extract the fiducials (gold beads) at the specified coordinates '
'locations.')

form.addParam('fiducialSize',
FloatParam,
default=10,
label='Fiducial size (nm)',
help='Peaked gold bead size in nanometers.')

form.addParam('misaliThrBool',
BooleanParam,
default=True,
Expand Down Expand Up @@ -141,6 +164,9 @@ def _defineParams(self, form):
def _insertAllSteps(self):
self.tomoDict = self.getTomoDict()

# Target sampling that fits the fiducial in 16 px (half od the box size to feed the network).
self.targetSamplingRate = self.fiducialSize.get() / 1.6

for key in self.tomoDict.keys():
tomo = self.tomoDict[key]
coordFilePath = self._getExtraPath(os.path.join(tomo.getTsId()), COORDINATES_FILE_NAME)
Expand All @@ -167,50 +193,63 @@ def extractSubtomos(self, key, coordFilePath):
outputPath = self._getExtraPath(os.path.join(tomo.getTsId()))
tomoFn = tomo.getFileName()

dsFactor = self.targetSamplingRate / tomo.getSamplingRate()

paramsExtractSubtomos = {
'tomogram': tomoFn,
'coordinates': coordFilePath,
'boxsize': BOX_SIZE,
'boxsize': TARGET_BOX_SIZE,
'threads': 1, # ***
'outputPath': outputPath,
'downsample': TARGET_SAMPLING_RATE / tomo.getSamplingRate(),
'downsample': dsFactor,
}

argsExtractSubtomos = "--tomogram %(tomogram)s " \
"--coordinates %(coordinates)s " \
"--boxsize %(boxsize)d " \
"--threads %(threads)d " \
"-o %(outputPath)s " \
"--downsample %(downsample)f "
"--downsample %(downsample)f " \
"--normalize " \
"--fixedBoxSize "

self.runJob('xmipp_tomo_extract_subtomograms', argsExtractSubtomos % paramsExtractSubtomos)

def subtomoPrediction(self, key):
tomo = self.tomoDict[key]

subtomoFilePath = self._getExtraPath(os.path.join(tomo.getTsId()), COORDINATES_FILE_NAME)

paramsMisaliPrediction = {
'modelPick': self.modelPick.get(),
'subtomoFilePath': subtomoFilePath
}

argsMisaliPrediction = "--modelPick %(modelPick)d " \
"--subtomoFilePath %(subtomoFilePath)s "

# Set misalignment threshold
if self.misaliThrBool.get():
paramsMisaliPrediction['misaliThr'] = self.misaliThr.get()

argsMisaliPrediction += "--misaliThr %(misaliThr)f "

# Set misalignment criteria
if self.misalignmentCriteria.get() == 1:
argsMisaliPrediction += "--misalignmentCriteriaVotes "

self.runJob('xmipp_deep_misalignment_detection',
argsMisaliPrediction % paramsMisaliPrediction,
env=self.getCondaEnv())
subtomoExtractedXmdFilePath = self._getExtraPath(os.path.join(tomo.getTsId()),
COORDINATES_EXTRACTED_FILE_NAME)

# Check if no coordinates have been extracted in the previous step
if os.path.exists(subtomoExtractedXmdFilePath):
paramsMisaliPrediction = {
'modelPick': self.modelPick.get(),
'subtomoFilePath': subtomoFilePath,
'g': self.getGpuList()[0],
}

argsMisaliPrediction = "--modelPick %(modelPick)d " \
"--subtomoFilePath %(subtomoFilePath)s " \
"-g %(g)s "

# Set misalignment threshold
if self.misaliThrBool.get():
paramsMisaliPrediction['misaliThr'] = self.misaliThr.get()

argsMisaliPrediction += "--misaliThr %(misaliThr)f "

# Set misalignment criteria
if self.misalignmentCriteria.get() == 1:
argsMisaliPrediction += "--misalignmentCriteriaVotes "

self.runJob('xmipp_deep_misalignment_detection',
argsMisaliPrediction % paramsMisaliPrediction,
env=self.getCondaEnv())
else:
self.info("WARNING: NO SUBTOMOGRAM ESTRACTED FOR TOMOGRAM " + tomo.getTsId() + " IMPOSSIBLE TO STUDY " +
"MISALIGNMENT!")

def createOutputStep(self, key, coordFilePath):
tomo = self.tomoDict[key]
Expand All @@ -229,10 +268,10 @@ def createOutputStep(self, key, coordFilePath):
firstPredictionArray, secondPredictionArray = self.readPredictionArrays(outputSubtomoXmdFilePath)
overallPrediction, predictionAverage = self.readTomoScores(outputTomoXmdFilePath)

print("For volume id " + str(tsId) + " obtained prediction from " + str(len(subtomoPathList)) +
" subtomos is " + str(overallPrediction))
self.info("For volume id " + str(tsId) + " obtained prediction from " + str(len(subtomoPathList)) +
" subtomos is " + str(overallPrediction))

tomo._misaliScore = predictionAverage
tomo._misaliScore = Float(predictionAverage)
self.addTomoToOutput(tomo=tomo, overallPrediction=overallPrediction)

for i, subtomoPath in enumerate(subtomoPathList):
Expand All @@ -246,18 +285,18 @@ def createOutputStep(self, key, coordFilePath):
subtomogram = SubTomogram()
subtomogram.setLocation(subtomoPath)
subtomogram.setCoordinate3D(newCoord3D)
subtomogram.setSamplingRate(TARGET_SAMPLING_RATE)
subtomogram.setSamplingRate(self.targetSamplingRate)
subtomogram.setVolName(tomo.getTsId())
subtomogram._strongMisaliScore = firstPredictionArray[i]
subtomogram._weakMisaliScore = secondPredictionArray[i]
subtomogram._strongMisaliScore = Float(firstPredictionArray[i])
subtomogram._weakMisaliScore = Float(secondPredictionArray[i])

self.outputSubtomos.append(subtomogram)
self.outputSubtomos.write()
self._store()
self.outputSubtomos.write()
self._store()

else:
print("WARNING: NO SUBTOMOGRAM ESTRACTED FOR TOMOGRAM " + tomo.getTsId() + "IMPOSSIBLE TO STUDY " +
"MISALIGNMENT!")
self.info("WARNING: NO SUBTOMOGRAM ESTRACTED FOR TOMOGRAM " + tomo.getTsId() + " IMPOSSIBLE TO STUDY " +
"MISALIGNMENT!")

def closeOutputSetsStep(self):
if self.alignedTomograms:
Expand All @@ -272,8 +311,9 @@ def closeOutputSetsStep(self):
self.strongMisalignedTomograms.setStreamState(Set.STREAM_CLOSED)
self.strongMisalignedTomograms.write()

self.outputSubtomos.setStreamState(Set.STREAM_CLOSED)
self.outputSubtomos.write()
if self.outputSubtomos:
self.outputSubtomos.setStreamState(Set.STREAM_CLOSED)
self.outputSubtomos.write()

self._store()

Expand All @@ -293,23 +333,34 @@ def getTomoDict(self):
return tomoDict

def addTomoToOutput(self, tomo, overallPrediction):
if overallPrediction == 1: # Strong misali
self.getOutputSetOfStrongMisalignedTomograms()
self.strongMisalignedTomograms.append(tomo)
self.strongMisalignedTomograms.write()
self._store()
self.info("Adding tomogram %s to set %d" % (tomo.getObjId(), overallPrediction))

elif overallPrediction == 2: # Weak misali
self.getOutputSetOfWeakMisalignedTomograms()
self.weakMisalignedTomograms.append(tomo)
self.weakMisalignedTomograms.write()
self._store()
try:
if overallPrediction == 1: # Strong misali
self.getOutputSetOfStrongMisalignedTomograms()
self.strongMisalignedTomograms.append(tomo)
self.strongMisalignedTomograms.write()
self._store()

elif overallPrediction == 3: # Ali
self.getOutputSetOfAlignedTomograms()
self.alignedTomograms.append(tomo)
self.alignedTomograms.write()
self._store()
elif overallPrediction == 2: # Weak misali
self.getOutputSetOfWeakMisalignedTomograms()
self.weakMisalignedTomograms.append(tomo)
self.weakMisalignedTomograms.write()
self._store()

elif overallPrediction == 3: # Ali
self.getOutputSetOfAlignedTomograms()
self.alignedTomograms.append(tomo)
self.alignedTomograms.write()
self._store()

except Exception as e:
if "UNIQUE" in str(e):
self.info("Error adding tomogram %s to set %d. It might be already added to set (duplicated id)" %
(tomo.getObjId(), overallPrediction))
else:
self.error("Error adding tomogram %s to set %d." % (tomo.getObjId(), overallPrediction))
self.error(str(e))

@staticmethod
def readPredictionArrays(outputSubtomoXmdFilePath):
Expand Down Expand Up @@ -392,7 +443,7 @@ def getOutputSetOfSubtomos(self):
outputSubtomos = self._createSetOfSubTomograms(suffix="FS")

outputSubtomos.copyInfo(self.isot)
outputSubtomos.setSamplingRate(TARGET_SAMPLING_RATE)
outputSubtomos.setSamplingRate(self.targetSamplingRate)

outputSubtomos.setStreamState(Set.STREAM_OPEN)

Expand Down
64 changes: 64 additions & 0 deletions xmipptomo/tests/test_protocols_apply_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# **************************************************************************
# *
# * Authors: Federico P. de Isidro-Gomez
# *
# * [1] Centro Nacional de Biotecnologia, CSIC, Spain
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address '[email protected]'
# *
# **************************************************************************


from pyworkflow.tests import BaseTest, setupTestProject
from tomo.tests import DataSet

from xmipptomo.protocols import XmippProtApplyTransformSubtomo, XmippProtPhantomSubtomo


class TestXmipptomoApplyTransf(BaseTest):
"""This class check if the protocol apply_alignment_subtomo works properly."""

@classmethod
def setUpClass(cls):
setupTestProject(cls)
cls.dataset = DataSet.getDataSet('tomo-em')

def _runPreviousProtocols(self):
protPhantom = self.newProtocol(XmippProtPhantomSubtomo, option=1, nsubtomos=5)
self.launchProtocol(protPhantom)
self.assertIsNotNone(protPhantom.outputSubtomograms,
"There was a problem with subtomograms output")
return protPhantom

def _applyAlignment(self):
protPhantom = self._runPreviousProtocols()
apply = self.newProtocol(XmippProtApplyTransformSubtomo,
inputSubtomograms=protPhantom.outputSubtomograms)
self.launchProtocol(apply)
self.assertIsNotNone(apply.outputSubtomograms,
"There was a problem with subtomograms output")
self.assertIsNotNone(apply.outputAverage,
"There was a problem with average output")
return apply

def test_applyAlignment(self):
align = self._applyAlignment()
self.assertTrue(getattr(align, 'outputSubtomograms'))
self.assertTrue(getattr(align, 'outputAverage'))
return align
Loading

0 comments on commit dab0362

Please sign in to comment.