diff --git a/README.rst b/README.rst index 0617f7d..3468480 100644 --- a/README.rst +++ b/README.rst @@ -1,8 +1,20 @@ -======= -emtools -======= -Utilities for CryoEM data manipulation. +.. |logo_image| image:: https://github.com/3dem/emhub/wiki/images/emtools-logo.png + :height: 60px + +|logo_image| + +**emtools** is a Python package with utilities for manipulating CryoEM images +and metadata such as STAR files or SQLite databases. It also contains other +miscellaneous utilities for process handling and monitoring, among others. + +The library is composed of several modules that mainly provide classes to +perform certain operations. + +For more detailed information, check the documentation at: + +https://3dem.github.io/emdocs/emtools/ + Installation ------------ @@ -16,8 +28,7 @@ Or for development: .. code-block:: bash git clone git@github.com:3dem/emtools.git - cd emtools - pip install -e . + pip install -e emtools/ Usage ----- diff --git a/docs/conf.py b/docs/conf.py index 1b7300e..5ebba1a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,5 @@ import datetime as dt +import emtools extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosectionlabel', @@ -48,7 +49,8 @@ #html_logo = "https://github.com/3dem/emhub/wiki/images/emhub-logo-top-gray.svg" html_context = { - 'last_updated': dt.datetime.now().date() + 'last_updated': dt.datetime.now().date(), + 'emtools_version': emtools.__version__ } templates_path = ["templates"] diff --git a/docs/index.rst b/docs/index.rst index 3a64d9c..cbc6e7d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,9 +32,10 @@ Modules and Classes .. toctree:: :maxdepth: 2 + utils metadata - utils - image + jobs + diff --git a/docs/jobs/index.rst b/docs/jobs/index.rst new file mode 100644 index 0000000..d6aea97 --- /dev/null +++ b/docs/jobs/index.rst @@ -0,0 +1,10 @@ + +jobs +==== + +.. 
toctree:: + :maxdepth: 1 + + Pipeline + + diff --git a/docs/jobs/pipeline.rst b/docs/jobs/pipeline.rst new file mode 100644 index 0000000..6290925 --- /dev/null +++ b/docs/jobs/pipeline.rst @@ -0,0 +1,7 @@ + +Pipeline +======== + +.. autoclass:: emtools.jobs.Pipeline + :members: + diff --git a/docs/metadata/starfile.rst b/docs/metadata/starfile.rst index 5c9c4c4..5012013 100644 --- a/docs/metadata/starfile.rst +++ b/docs/metadata/starfile.rst @@ -2,10 +2,14 @@ StarFile ======== -This class allows to read and write data blocks from STAR files. -It is also designed to inspect the data blocks, columns -and number of elements in an efficient manner without parsing the entire -file. +This class allows you to read and write data blocks from STAR files. +There are some features that make this class useful for the manipulation of data in STAR format: + + +* Inspect data blocks, columns, and the number of elements without parsing the entire file. +* Iterate over the rows without reading the whole table in memory. +* Read or iterate over a subset of rows only. + Inspecting without fully parsing -------------------------------- @@ -43,6 +47,8 @@ https://github.com/3dem/em-testdata/blob/main/metadata/run_it025_data.star #>>> items: 4786 #>>> columns: 25 + sf.close() + The methods used in the previous example (`getTableNames`, `getTableSize`, and `getTableInfo`) all inspect the STAR file without fully parsing all rows. This way is much more faster @@ -50,9 +56,35 @@ that parsing rows if not needed. These methods will also create an index of wher data blocks are in the file, so if you need to read a data table, it will jump to that position in the file. +Reading a Table +--------------- + +After opening a StarFile, we can easily read any table using the function `getTable` +as shown in the following example: + +.. 
code-block:: python + + with StarFile(movieStar) as sf: + # Read table in a different order as they appear in file + # also before the getTableNames() call that create the offsets + t1 = sf.getTable('local_shift') + t2 = sf.getTable('general') + +This function has some optional arguments such as *guessType* for inferring +the column type from the first row. In some cases this is not desired and one +can pass *guessType=False* and then all columns will be treated as strings. +For example, reading the ``job.star`` file: + +.. code-block:: python + + with StarFile('job.star') as sf: + print(f"Tables: {sf.getTableNames()}") # ['job', 'joboptions_values'] + t = sf.getTable('joboptions_values', guessType=False) + values = {row.rlnJobOptionVariable: row.rlnJobOptionValue for row in t} -Iterating over the rows ------------------------ + +Iterating over the Table rows +----------------------------- In some cases, we just want to iterate over the rows and operate on them one by one. In that case, it is not necessary to fully load the whole table in memory. Iteration @@ -60,14 +92,154 @@ also allows to read range of rows but not all of them. This is specially useful for visualization purposes, where we can show a number of elements and allow to go through all of them in an efficient manner. +Please check the :ref:`Examples` for practical use cases. + Writing STAR files ------------------ +It is easy to write STAR files using the :class:`StarFile` class. We just need to +open it with write mode enabled. Then we could just write a table header +and then write rows one by one, or we could write an entire table at once. + +Please check the :ref:`Examples` for practical use cases. + +Examples +-------- + +Parsing EPU's XML files +....................... + +Although the :py:class:`StarFile` class has been used mainly to handle Relion STAR files, +we can use any label and table names. 
For example, if we want to parse the +XML files from EPU to extract the beam shift per movie, and write an output +STAR file: -Comparison with other libraries -------------------------------- +.. code-block:: python + + out = StarFile(outputStar, 'w') + t = Table(['movieBaseName', 'beamShiftX', 'beamShiftY']) + out.writeHeader('Movies', t) + + for base, x, y in EPU.get_beam_shifts(inputDir): + out.writeRow(t.Row(movieBaseName=base, beamShiftX=x, beamShiftY=y)) + + out.close() + + +Note in this example that we are not storing the whole table in memory. We just +create an empty table with the desired columns and then we write one row for +each XML file parsed. + + +Balancing Particles views based on orientation angles +..................................................... + +We could read the Rot and Tilt angles from a particles STAR file as numpy arrays: + +.. code-block:: python + + with StarFile('particles.star') as sf: + size = sf.getTableSize('particles') + info = sf.getTableInfo('particles') + # Initialize the numpy arrays with zero and the number of particles + anglesRot = np.zeros(size) + anglesTilt = np.zeros(size) + # Then iterate the rows and store only these values + for i, p in enumerate(sf.iterTable('particles')): + anglesRot[i] = p.rlnAngleRot + anglesTilt[i] = p.rlnAngleTilt + + +Then we can use these arrays to plot the values, identify the denser angular +regions, and create a subset of points to make the distribution more even. +Let's assume we computed the list of points to remove in the list *to_remove*. +Now, we can go through the input *particles.star* and we will create a similar +one, but with some particles removed. We will copy every table into the output +STAR file, except for the *particles* one, where we need to filter out some +particles. We can do it with the following code: + +.. 
code-block:: python + + with StarFile('particles.star') as sf: + with StarFile('output_particles.star', 'w') as outSf: + # Preserve all tables, except particles that will be a subset + for tableName in sf.getTableNames(): + if tableName == 'particles': + info = sf.getTableInfo('particles') + table = Table(columns=info.getColumns()) + outSf.writeHeader('particles', table) + counter = 0 + for i, p in enumerate(sf.iterTable('particles')): + if i == to_remove[counter]: # Skip this item + counter += 1 + continue + outSf.writeRow(p) + else: + table = sf.getTable(tableName) + outSf.writeTable(tableName, table) + +Converting from Scipion to micrographs STAR file +................................................ + +The following function shows how we can write a *micrographs.star* file +from a Scipion set of CTFs: + +.. code-block:: python + def write_micrographs_star(micStarFn, ctfs): + firstCtf = ctfs.getFirstItem() + firstMic = firstCtf.getMicrograph() + acq = firstMic.getAcquisition() + + with StarFile(micStarFn, 'w') as sf: + optics = Table(['rlnOpticsGroupName', + 'rlnOpticsGroup', + 'rlnMicrographOriginalPixelSize', + 'rlnVoltage', + 'rlnSphericalAberration', + 'rlnAmplitudeContrast', + 'rlnMicrographPixelSize']) + ps = firstMic.getSamplingRate() + op = 1 + opName = f"opticsGroup{op}" + optics.addRowValues(opName, op, ps, + acq.getVoltage(), + acq.getSphericalAberration(), + acq.getAmplitudeContrast(), + ps) + + sf.writeLine("# version 30001") + sf.writeTable('optics', optics) + + mics = Table(['rlnMicrographName', + 'rlnOpticsGroup', + 'rlnCtfImage', + 'rlnDefocusU', + 'rlnDefocusV', + 'rlnCtfAstigmatism', + 'rlnDefocusAngle', + 'rlnCtfFigureOfMerit', + 'rlnCtfMaxResolution', + 'rlnMicrographMovieName']) + sf.writeLine("# version 30001") + sf.writeHeader('micrographs', mics) + + for ctf in ctfs: + mic = ctf.getMicrograph() + u, v, a = ctf.getDefocus() + micName = mic.getMicName() + movName = os.path.join('data', 'Images-Disc1', + micName.replace('_Data_FoilHole_', + 
'/Data/FoilHole_')) + row = mics.Row(mic.getFileName(), op, + ctf.getPsdFile(), + u, v, abs(u - v), a, + ctf.getFitQuality(), + ctf.getResolution(), + movName) + + sf.writeRow(row) Reference --------- diff --git a/docs/templates/sidebar/brand.html b/docs/templates/sidebar/brand.html new file mode 100644 index 0000000..3304a32 --- /dev/null +++ b/docs/templates/sidebar/brand.html @@ -0,0 +1,34 @@ +{#- + +Hi there! + +You might be interested in https://pradyunsg.me/furo/customisation/sidebar/ + +Although if you're reading this, chances are that you're either familiar +enough with Sphinx that you know what you're doing, or landed here from that +documentation page. + +Hope your day's going well. :) + +-#} + diff --git a/docs/utils.rst b/docs/utils.rst deleted file mode 100644 index f5af62a..0000000 --- a/docs/utils.rst +++ /dev/null @@ -1,5 +0,0 @@ - -utils -===== - -Miscellaneous utilities diff --git a/docs/utils/index.rst b/docs/utils/index.rst new file mode 100644 index 0000000..1dc4e00 --- /dev/null +++ b/docs/utils/index.rst @@ -0,0 +1,11 @@ + +utils +===== + +.. toctree:: + :maxdepth: 1 + + Color, Pretty, Timer + Process, Path, System + + diff --git a/docs/utils/misc.rst b/docs/utils/misc.rst new file mode 100644 index 0000000..c5933bc --- /dev/null +++ b/docs/utils/misc.rst @@ -0,0 +1,18 @@ + +Color +===== + +.. autoclass:: emtools.utils.Color + :members: + +Pretty +====== + +.. autoclass:: emtools.utils.Pretty + :members: + +Timer +===== + +.. autoclass:: emtools.utils.Timer + :members: diff --git a/docs/utils/process.rst b/docs/utils/process.rst new file mode 100644 index 0000000..60f2e14 --- /dev/null +++ b/docs/utils/process.rst @@ -0,0 +1,19 @@ + +Process +======= + +.. autoclass:: emtools.utils.Process + :members: + +Path +==== + +.. autoclass:: emtools.utils.Path + :members: + +System +====== + +.. 
autoclass:: emtools.utils.System + :members: + diff --git a/emtools/__init__.py b/emtools/__init__.py index 3fc9a10..aed35e0 100644 --- a/emtools/__init__.py +++ b/emtools/__init__.py @@ -24,5 +24,5 @@ # * # ************************************************************************** -__version__ = '0.0.11' +__version__ = '0.1.0' diff --git a/emtools/image/thumbnail.py b/emtools/image/thumbnail.py index 0469eb5..0da7d7d 100644 --- a/emtools/image/thumbnail.py +++ b/emtools/image/thumbnail.py @@ -47,6 +47,7 @@ def __init__(self, **kwargs): self.scale = 1.0 self.output_format = kwargs.get('output_format', None) self.min_max = kwargs.get('min_max', None) + self.std_threshold = kwargs.get('std_threshold', 0) def __format(self, pil_img): @@ -98,15 +99,27 @@ def from_path(self, path): return encoded def from_array(self, imageArray): - # imean = imageArray.mean() - # isd = imageArray.std() + if self.min_max: iMin, iMax = self.min_max + array = imageArray else: - iMax = imageArray.max() # min(imean + 10 * isd, imageArray.max()) - iMin = imageArray.min() # max(imean - 10 * isd, imageArray.min()) - - im255 = ((imageArray - iMin) / (iMax - iMin) * 255).astype(np.uint8) + if self.std_threshold > 0: + array = np.array(imageArray) + imean = array.mean() + isd = array.std() + isdTh = self.std_threshold * isd + minTh = imean - isdTh + maxTh = imean + isdTh + array[array < minTh] = minTh + array[array > maxTh] = maxTh + else: + array = imageArray + + iMax = array.max() + iMin = array.min() + + im255 = ((array - iMin) / (iMax - iMin) * 255).astype(np.uint8) pil_img = Image.fromarray(im255) @@ -128,3 +141,30 @@ def from_mrc(self, mrc_path): return result + @staticmethod + def Micrograph(**kwargs): + """ Shortcut method with presets for Micrograph thumbail. + All settings can be overwriten with kwargs. 
+ """ + defaults = { + 'output_format': 'base64', + 'max_size': (512, 512), + 'contrast_factor': 0.15, + 'std_threshold': 1 + } + defaults.update(kwargs) + return Thumbnail(**defaults) + + @staticmethod + def Psd(**kwargs): + """ Shortcut method with presets for PSD thumbails. + All settings can be overwriten with kwargs. + """ + defaults = { + 'output_format': 'base64', + 'max_size': (128, 128), + 'contrast_factor': 1 + } + defaults.update(kwargs) + return Thumbnail(**defaults) + diff --git a/emtools/processing/__init__.py b/emtools/jobs/__init__.py similarity index 93% rename from emtools/processing/__init__.py rename to emtools/jobs/__init__.py index 8906b08..fff9a6d 100644 --- a/emtools/processing/__init__.py +++ b/emtools/jobs/__init__.py @@ -14,3 +14,6 @@ # * # ************************************************************************** +from .pipeline import Pipeline + +__all__ = ["Pipeline"] \ No newline at end of file diff --git a/emtools/processing/__main__.py b/emtools/jobs/__main__.py similarity index 100% rename from emtools/processing/__main__.py rename to emtools/jobs/__main__.py diff --git a/emtools/processing/motioncor.py b/emtools/jobs/motioncor.py similarity index 100% rename from emtools/processing/motioncor.py rename to emtools/jobs/motioncor.py diff --git a/emtools/utils/pipeline.py b/emtools/jobs/pipeline.py similarity index 100% rename from emtools/utils/pipeline.py rename to emtools/jobs/pipeline.py diff --git a/emtools/metadata/epu.py b/emtools/metadata/epu.py index 2228413..2190866 100644 --- a/emtools/metadata/epu.py +++ b/emtools/metadata/epu.py @@ -82,15 +82,14 @@ def get_beam_shifts(xmlDir): missing_xml = [] for root, dirs, files in os.walk(xmlDir): for fn in files: - f = os.path.join(root, fn) # Check existing movies first - - if xmlFn := EPU.get_movie_xml(f): - if os.path.exists(xmlFn): - x, y = EPU.parse_beam_shifts(xmlFn) - yield os.path.basename(fn), x, y + if xmlFn := EPU.get_movie_xml(fn): + xmlPath = os.path.join(root, xmlFn) + 
if os.path.exists(xmlPath): + x, y = EPU.parse_beam_shifts(xmlPath) + yield fn, x, y else: - missing_xml.append(f) + missing_xml.append(xmlPath) if missing_xml: print('Missing XML files for the following movies:') diff --git a/emtools/metadata/misc.py b/emtools/metadata/misc.py index 74510d9..69d857f 100644 --- a/emtools/metadata/misc.py +++ b/emtools/metadata/misc.py @@ -185,6 +185,10 @@ def register(self, filename, stat=None): def total_files(self): return self.counters[0].total + @property + def total_size(self): + return self.counters[0].total_size + def __contains__(self, filename): fn = filename.replace(self.root, '') return fn in self._index_files diff --git a/emtools/metadata/starfile.py b/emtools/metadata/starfile.py index 0a65b3f..b5dd2af 100644 --- a/emtools/metadata/starfile.py +++ b/emtools/metadata/starfile.py @@ -22,6 +22,7 @@ __author__ = 'Jose Miguel de la Rosa Trevin, Grigory Sharov' import sys +import re from contextlib import AbstractContextManager from .table import ColumnList, Table @@ -36,6 +37,10 @@ class StarFile(AbstractContextManager): to queries table's columns or size without parsing all data rows. """ + + # Compile regex to split data lines taking into account string literals + _splitRegex = re.compile('\"[^"]*\"|[^"\s]+') + @staticmethod def printTable(table, tableName=''): w = StarFile(sys.stdout) @@ -115,15 +120,18 @@ def getTable(self, tableName, **kwargs): self._table.addRow(self.__rowFromValues(self._values)) else: for line in self._iterRowLines(): - self._table.addRow(self.__rowFromValues(line.split())) + self._table.addRow(self.__rowFromValues(self.__split_line(line))) return self._table def getTableSize(self, tableName): - """ Return the number of elements in the given table. - This method is much more efficient that parsing the table - and getting the size, if the one is only interested in the - number of elements in the table. 
+ """ + Return the number of elements in the given table without parsing + all the rows of the table. + + If one is only interested in the number of items in a table, + this method is much more efficient than parsing all rows in + the table. """ self._loadTableInfo(tableName) if self._singleRow: @@ -159,7 +167,7 @@ def iterTable(self, tableName, **kwargs): for i, line in enumerate(self._iterRowLines()): if i >= start: c += 1 - yield self.__rowFromValues(line.split()) + yield self.__rowFromValues(self.__split_line(line)) if limit and c == limit: break @@ -173,6 +181,13 @@ def getTableRow(self, tableName, rowIndex, **kwargs): def __loadFile(self, inputFile, mode): return open(inputFile, mode) if isinstance(inputFile, str) else inputFile + def __split_line(self, line, default=[]): + """ Split a data line taking into account string literals """ + if '"' in line: + return self._splitRegex.findall(line) if line else default + + return line.split() if line else default + def _loadTableInfo(self, tableName): self._findDataLine(tableName) @@ -191,12 +206,14 @@ def _loadTableInfo(self, tableName): self._singleRow = not self._foundLoop if self._foundLoop: - values = self._line.split() if self._line else [] + values = self.__split_line(self._line) self._colNames = colNames self._values = values def __rowFromValues(self, values): + if not values: + return None try: return self._table.Row(*[t(v) for t, v in zip(self._types, values)]) except Exception as e: @@ -204,6 +221,9 @@ def __rowFromValues(self, values): print("values: ", values) raise e + def __rowFromLine(self, line): + return self.__rowFromValues(self.__split_line(line)) + def _getRow(self): """ Get the next Row, it is None when not more rows. 
""" result = self._row @@ -212,7 +232,7 @@ def _getRow(self): self._row = None elif result is not None: line = self._file.readline().strip() - self._row = self.__rowFromValues(line.split()) if line else None + self._row = self.__rowFromValues(self.__split_line(line)) return result @@ -298,7 +318,7 @@ def writeSingleRow(self, tableName, row): m = max([len(c) for c in row._fields]) + 5 format = "_{:<%d} {:>10}\n" % m for col, value in row._asdict().items(): - self._file.write(format.format(col, value)) + self._file.write(format.format(col, _escapeStrValue(value))) self._file.write('\n\n') def writeHeader(self, tableName, table): @@ -318,6 +338,8 @@ def _writeRowValues(self, values): """ if not self._format: self._computeLineFormat([values]) + + values = [_escapeStrValue(v) for v in values] self._file.write(self._format.format(*values)) def writeRow(self, row): @@ -378,3 +400,7 @@ def _getFormatStr(v): return '.6f' if isinstance(v, float) else '' +def _escapeStrValue(v): + """ Escape string values by adding quotes if the string + is empty or contains spaces. """ + return '"%s"' % v if isinstance(v, str) and (not v or ' ' in v) else v diff --git a/emtools/metadata/table.py b/emtools/metadata/table.py index 2b05dd0..26126f3 100644 --- a/emtools/metadata/table.py +++ b/emtools/metadata/table.py @@ -28,7 +28,7 @@ class Column: def __init__(self, name, type=None): self._name = name - self._type = type or str + self._type = type or _str def __str__(self): return 'Column: %s (type: %s)' % (self._name, self._type) @@ -128,7 +128,7 @@ def createColumns(colNames, values, guessType=True, types=None): elif guessType and values: colType = _guessType(values[i]) else: - colType = str + colType = _str columns.append(Column(colName, colType)) return columns @@ -136,7 +136,7 @@ def createColumns(colNames, values, guessType=True, types=None): class Table(ColumnList): """ - Class to hold and manipulate tabular data for EM processing programs. + Class to hold and manipulate tabular data. 
""" def __init__(self, columns=None): ColumnList.__init__(self, columns) @@ -229,7 +229,7 @@ def removeColumns(self, *args): oldColumns = self._columns oldRows = self._rows - # Remove non desired columns and create again the Row class + # Remove undesired columns and create again the Row class self._columns = OrderedDict([(k, v) for k, v in oldColumns.items() if k not in rmCols]) self.Row = self.createRowClass() @@ -244,8 +244,12 @@ def removeColumns(self, *args): def getColumnValues(self, colName): """ Return the values of a given column - :param colName: The name of an existing column to retrieve values. - :return: A list with all values of that column. + + Args: + colName: The name of an existing column to retrieve values. + + Return: + A list with all values of that column. """ if colName not in self._columns: raise Exception("Not existing column: %s" % colName) @@ -276,6 +280,10 @@ def __setitem__(self, key, value): # --------- Helper functions ------------------------ +def _str(s): + """ Get the string value but stripping quotes if present. 
""" + return s[1:-1] if s.startswith('"') and s.endswith('"') else s + def _guessType(strValue): try: @@ -286,7 +294,7 @@ def _guessType(strValue): float(strValue) return float except ValueError: - return str + return _str def _formatValue(v): diff --git a/emtools/scripts/emt-scipion-otf.py b/emtools/scripts/emt-scipion-otf.py index 6155cdf..26cb784 100755 --- a/emtools/scripts/emt-scipion-otf.py +++ b/emtools/scripts/emt-scipion-otf.py @@ -267,27 +267,11 @@ def create_project(workingDir): def _path(*p): return os.path.join(workingDir, *p) - """ - {"acquisition": {"voltage": 200, "magnification": 79000, "pixel_size": 1.044, "dose": 1.063, "cs": 2.7}} - """ - scipionOptsFn = _path('scipion_otf_options.json') relionOptsFn = _path('relion_it_options.py') - if os.path.exists(scipionOptsFn): - with open(scipionOptsFn) as f: - opts = json.load(f) - - elif os.path.exists(relionOptsFn): - with open(_path('relion_it_options.py')) as f: - relionOpts = OrderedDict(ast.literal_eval(f.read())) - opts = {'acquisition': { - 'voltage': relionOpts['prep__importmovies__kV'], - 'pixel_size': relionOpts['prep__importmovies__angpix'], - 'cs': relionOpts['prep__importmovies__Cs'], - 'magnification': 130000, - 'dose': relionOpts['prep__motioncorr__dose_per_frame'] - }} + with open(scipionOptsFn) as f: + opts = json.load(f) acq = opts['acquisition'] picking = opts.get('picking', {}) @@ -393,6 +377,11 @@ def _path(*p): wf.launchProtocol(protCryolo, wait={OUT_COORD: 100}) + skip_2d = not opts.get('2d', True) + + if skip_2d: + return + calculateBoxSize(protCryolo) protRelionExtract = wf.createProtocol( @@ -476,7 +465,7 @@ def print_protocol(workingDir, protId): if protId == 'all': for prot in project.getRuns(iterate=True): clsName = prot.getClassName() - print(f"- {prot.getObjId():>8} {prot.getStatus():<10} {clsName}") + print(f"- {prot.getObjId():>6} {prot.getStatus():<10} {clsName:<30} - {prot.getRunName()}") else: prot = project.getProtocol(int(protId)) if prot is None: @@ -571,6 +560,9 
@@ def write_coordinates(micStarFn, prot): for coord in coords.iterItems(orderBy='_micId', direction='ASC'): micId = coord.getMicId() + if micId not in micDict: + continue + if micId not in micIds: micIds.add(micId) micFn = micDict[micId] @@ -598,7 +590,9 @@ def print_prot(prot, label='Protocol'): def write_stars(workingDir, ids=None): - """ Restart one or more protocols. """ + """ Write star files for Relion. Generates micrographs_ctf.star, + coordinates.star and Coordinates folder. + """ print("ids", ids) def _get_keys(tokens): @@ -616,8 +610,12 @@ def _get_keys(tokens): idsDict = {k: v for k, v in _get_keys(ids)} if 'ctfs' in idsDict: protCtf = project.getProtocol(idsDict['ctfs']) + if protCtf is None: + raise Exception(f"There is no CTF protocol with id {idsDict['ctfs']}") if 'picking' in idsDict: protPicking = project.getProtocol(idsDict['picking']) + if protPicking is None: + raise Exception(f"There is no CTF protocol with id {idsDict['picking']}") else: # Default option when running OTF that we export STAR files # from CTFFind and Cryolo runs @@ -704,7 +702,7 @@ def main(): p = argparse.ArgumentParser(prog='scipion-otf') g = p.add_mutually_exclusive_group() - g.add_argument('--create', action='store_true', + g.add_argument('--create', metavar='Scipion project path', help="Create a new Scipion project in the working " "directory. This will overwrite any existing " "'scipion' folder there.") @@ -721,6 +719,7 @@ def main(): g.add_argument('--clean', action="store_true", help="Clean Scipion project files/folders.") g.add_argument('--continue_2d', action="store_true") + g.add_argument('--write_stars', default=argparse.SUPPRESS, nargs='*', help="Generate STAR micrographs and particles STAR files." 
"By default, it will get the first CTFfind protocol for ctfs" @@ -737,7 +736,7 @@ def main(): args = p.parse_args() if args.create: - create_project(cwd) + create_project(args.create) elif args.restart: restart(cwd, args.restart) elif args.restart_rankers: diff --git a/emtools/scripts/emt_beamshifts.py b/emtools/scripts/emt_beamshifts.py index cb72227..bba7073 100755 --- a/emtools/scripts/emt_beamshifts.py +++ b/emtools/scripts/emt_beamshifts.py @@ -31,6 +31,8 @@ def parse(inputDir, outputStar): print("\rParsed: ", i, end="") out.writeRow(t.Row(movieBaseName=base, beamShiftX=x, beamShiftY=y)) + print() + if outputStar: out.close() diff --git a/emtools/tests/__init__.py b/emtools/tests/__init__.py index c947fc2..579554b 100644 --- a/emtools/tests/__init__.py +++ b/emtools/tests/__init__.py @@ -6,10 +6,17 @@ EM_TESTDATA = os.environ.get('EM_TESTDATA', None) +_EM_TESTDATA_WARN = True # Warn once if EM_TEST_DATA is not configured + def testpath(*paths): """ Return paths from EM_TESTDATA. """ if EM_TESTDATA is None: + global _EM_TESTDATA_WARN + if _EM_TESTDATA_WARN: + print(f">>> Warning, {Color.warn('EM_TESTDATA')} variable not " + f"defined, some test might not be executed.\n") + _EM_TESTDATA_WARN = False return None p = os.path.abspath(os.path.join(EM_TESTDATA, *paths)) diff --git a/emtools/tests/test_metadata.py b/emtools/tests/test_metadata.py index ebe3b32..84b41e4 100644 --- a/emtools/tests/test_metadata.py +++ b/emtools/tests/test_metadata.py @@ -22,6 +22,17 @@ from emtools.metadata import StarFile, SqliteFile, EPU from emtools.tests import testpath +# Try to load starfile library to launch some comparisons +try: + import starfile +except: + starfile = None + +try: + import emtable +except: + emtable = None + class TestStarFile(unittest.TestCase): """ @@ -128,10 +139,12 @@ def test_read_particlesStar(self): otable = sf.getTable('optics') self.assertEqual(len(otable), 1) + ftmp = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.star') + print(f">>>> 
{ftmp.name}") + def _enlarge(inputStar, n): """ Enlarge input star file. """ lines = [] - ftmp = tempfile.TemporaryFile('w+') with open(inputStar) as f: for line in f: ftmp.write(line) @@ -141,26 +154,91 @@ def _enlarge(inputStar, n): for line in lines: ftmp.write(line) ftmp.seek(0) - return ftmp n = 1000 nn = 1000 * (n + 1) - part1m = _enlarge(partStar, n) + _enlarge(partStar, n) + + ftmp.close() - with StarFile(part1m) as sf: - t = Timer() + print(Color.cyan(f"Testing starfile with {nn} particles")) + + t = Timer() + tmpStar = ftmp.name + + with StarFile(tmpStar) as sf: + t.tic() otable = sf.getTable('optics') t.toc('Read optics:') + t.tic() + size = sf.getTableSize('particles') + t.toc(f'Counted {size} particles:') + self.assertEqual(size, nn) + t.tic() ptable = sf.getTable('particles') t.toc(f'Read {len(ptable)} particles:') self.assertEqual(len(ptable), nn) + if emtable: t.tic() - size = sf.getTableSize('particles') - t.toc(f'Counted {size} particles:') - self.assertEqual(size, nn) + table = emtable.Table(fileName=tmpStar, tableName='particles') + t.toc("Read with 'emtable'") + + if starfile: + t.tic() + df = starfile.read(tmpStar) + t.toc("Read with 'starfile'") + + os.unlink(ftmp.name) + + def test_read_jobstar(self): + jobStar = testpath('metadata', 'relion5_job002_job.star') + if jobStar is None: + return + + expected_values = { + 'bfactor': '150', + 'bin_factor': '1', + 'do_dose_weighting': 'Yes', + 'do_queue': 'No', + 'dose_per_frame': '1.277', + 'eer_grouping': '32', + 'first_frame_sum': '1', + 'fn_defect': '', + 'fn_gain_ref': 'Movies/gain.mrc', + 'gain_flip': 'No flipping (0)' + } + + def _checkValues(t): + values = {row.rlnJobOptionVariable: row.rlnJobOptionValue for row in t} + for k, v in expected_values.items(): + #print(f"{k} = {v}") + self.assertEqual(v, values[k]) + + expected_tables = ['job', 'joboptions_values'] + with StarFile(jobStar) as sf: + print(f"Tables: {sf.getTableNames()}") + self.assertEqual(expected_tables, sf.getTableNames()) + 
t1 = sf.getTable('joboptions_values', guessType=False) + _checkValues(t1) + + # Test that we can write values + # with empty string and spaces + ftmp = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.star') + print(f">>>> Writting to {ftmp.name}") + + with StarFile(ftmp) as sf: + sf.writeTable('joboptions', t1) + + ftmp.close() + + with StarFile(ftmp.name) as sf: + t2 = sf.getTable('joboptions', guessType=False) + _checkValues(t2) + + os.unlink(ftmp.name) class TestEPU(unittest.TestCase): @@ -206,7 +284,7 @@ def test_readMovies(self): with SqliteFile(movieSqlite) as sf: self.assertEqual(sf.getTableNames(), self.BASIC_TABLES) - self.assertEqual(sf.getTableSize('Objects'), 19078) + self.assertEqual(sf.getTableSize('Objects'), 3775) props = [row for row in sf.iterTable('Properties')] self.assertEqual(len(props), 22) @@ -227,7 +305,7 @@ def test_getTableRow(self): with SqliteFile(movieSqlite) as sf: t1 = [row for row in sf.iterTable('Objects')] - for i in [0, 1, 2, 19077]: + for i in [0, 1, 2, 3774]: row = sf.getTableRow('Objects', i) self.assertEqual(row, t1[i]) @@ -247,12 +325,8 @@ def test_readParticles(self): self.assertEqual(sf.getTableNames(), self.BASIC_TABLES) t.tic() - self.assertEqual(sf.getTableSize('Objects'), 1417708) + self.assertEqual(sf.getTableSize('Objects'), 130987) t.toc("Size of particles") rows = [r for r in sf.iterTable('Classes')] - self.assertEqual(len(rows), 49) - - # for row in sf.iterTable('Objects', limit=1, classes='Classes'): - # for k, v in row.items(): - # print(f"{k:>10}: {v}") + self.assertEqual(len(rows), 45) diff --git a/emtools/tests/test_pipeline.py b/emtools/tests/test_pipeline.py index e9f308d..33687c5 100644 --- a/emtools/tests/test_pipeline.py +++ b/emtools/tests/test_pipeline.py @@ -18,7 +18,7 @@ import numpy as np import time -from emtools.utils import Pipeline +from emtools.jobs import Pipeline class TestThreading(unittest.TestCase): diff --git a/emtools/utils/__init__.py b/emtools/utils/__init__.py index 
4581338..01f256c 100644 --- a/emtools/utils/__init__.py +++ b/emtools/utils/__init__.py @@ -15,15 +15,16 @@ # ************************************************************************** from .color import Color -from .process import Process from .pretty import Pretty -from .path import Path from .time import Timer -from .pipeline import Pipeline + +from .process import Process +from .path import Path from .system import System + from .server import JsonTCPServer, JsonTCPClient -__all__ = [Color, Process, Pretty, Path, Timer, Pipeline, System, - JsonTCPServer, JsonTCPClient] +__all__ = ["Color", "Pretty", "Timer", "Process", "Path", "System", + "JsonTCPServer", "JsonTCPClient"] diff --git a/emtools/utils/color.py b/emtools/utils/color.py index b06f0e3..627bc28 100644 --- a/emtools/utils/color.py +++ b/emtools/utils/color.py @@ -26,6 +26,9 @@ class Color: + """ Basic helper class to have colored string. + Useful for commands and log messages. """ + @staticmethod def green(msg): return f'{OKGREEN}{msg}{ENDC}' diff --git a/emtools/utils/path.py b/emtools/utils/path.py index 094e568..b1049be 100644 --- a/emtools/utils/path.py +++ b/emtools/utils/path.py @@ -1,18 +1,33 @@ +# ************************************************************************** +# * +# * Authors: J.M. de la Rosa Trevin (delarosatrevin@gmail.com) +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. 
+# * +# ************************************************************************** import os import time -from glob import glob from datetime import datetime as dt -from datetime import timedelta from collections import OrderedDict -import numpy as np from .pretty import Pretty from .process import Process -from .color import Color class Path: + """ + Group some path utility functions. + """ + class ExtDict(OrderedDict): """ Keep track of number of files and size by extension. """ def register(self, filename, stat=None): @@ -64,10 +79,12 @@ def splitall(path): @staticmethod def addslash(path): + """ Add a slash (/) to the end of the path if not present. """ return path if path.endswith('/') else path + '/' @staticmethod def rmslash(path): + """ Remove the slash (/) from the end of the path if present. """ return path[:-1] if path.endswith('/') else path @staticmethod @@ -149,7 +166,7 @@ def _mkdir(d): @staticmethod def replaceExt(filename, newExt): """ Replace the current path extension(from last .) - with a new one. The new one should not contains the .""" + with a new one. The new one should not contain the .""" return Path.removeExt(filename) + '.' + newExt @staticmethod diff --git a/emtools/utils/pretty.py b/emtools/utils/pretty.py index 6dce00e..626f4c3 100644 --- a/emtools/utils/pretty.py +++ b/emtools/utils/pretty.py @@ -20,6 +20,9 @@ class Pretty: + """ Helper class for "pretty" string formatting from several input types + (e.g. size, dates, timestamps, elapsed, etc.). + """ # Default timestamp DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S' DATE_FORMAT = '%Y-%m-%d' diff --git a/emtools/utils/process.py b/emtools/utils/process.py index 8176d8d..4ff2106 100644 --- a/emtools/utils/process.py +++ b/emtools/utils/process.py @@ -25,7 +25,7 @@ class Process: def __init__(self, *args, **kwargs): - """ Create a process using subprocess.run.
""" + """ Create a process using subprocess.""" self.args = args error = '' try: @@ -44,6 +44,8 @@ def __init__(self, *args, **kwargs): raise Exception(error) def lines(self): + """ Iterate over the lines of the process output. + """ for line in self.stdout.split('\n'): yield line @@ -55,6 +57,14 @@ def print(self, args=True, stdout=False): @staticmethod def system(cmd, only_print=False, color=None): + """ Execute and print a command. + + Args: + cmd: Command to be executed. + only_print: If true, the command will only be printed and + not executed + color: Optional color for the command + """ printCmd = cmd if color is None else color(cmd) print(printCmd) if not only_print: @@ -63,6 +73,7 @@ def system(cmd, only_print=False, color=None): @staticmethod def ps(program, workingDir=None, children=False): """ Inspect processes matching a given program name. + Args: program: string matching the program name workingDir: if not None, filter processes only with that folder as @@ -92,7 +103,7 @@ def _addProc(f, proc): return processes class Logger: - """ Use a logger to log command that are executed via os.system. """ + """ Use a logger to log commands that are executed via os.system. """ def __init__(self, logger=None, only_log=False, format='%(asctime)s %(levelname)s %(message)s'): # If not logger, create one using stdout @@ -115,11 +126,13 @@ def __init__(self, logger=None, only_log=False, def system(self, cmd, retry=None): """ Execute a command and log it. + Args: cmd: Command string to be executed with os.system retry: If not None, it should be the time in seconds after which the command will be re-executed on failure until successful completion. + Return: last exit_status from os.system result """ @@ -146,11 +159,11 @@ def mkdir(self, path, retry=None): self.system(f"mkdir -p '{path}'", retry=retry) def cp(self, src, dst, retry=None): - """ Make a folder path. """ + """ Copy from src to dst. 
""" self.system(f"cp '{src}' '{dst}'", retry=retry) def mv(self, src, dst, retry=None): - """ Make a folder path. """ + """ Move from src to dst. """ self.system(f"mv '{src}' '{dst}'", retry=retry) def rm(self, path): diff --git a/emtools/utils/system.py b/emtools/utils/system.py index fc46323..8f54c5d 100644 --- a/emtools/utils/system.py +++ b/emtools/utils/system.py @@ -23,6 +23,7 @@ import socket import platform import psutil + from .process import Process @@ -33,6 +34,7 @@ class System: @staticmethod def gpus(): + """ Return a dictionary with existing GPUs and their properties. """ gpus = [] query = Process(System.NVIDIA_SMI_QUERY[0], *System.NVIDIA_SMI_QUERY[1:], doRaise=False) @@ -69,6 +71,7 @@ def cpus(): @staticmethod def memory(): + """ Return the total virtual memory of the system. (in Gb) """ return psutil.virtual_memory().total // (1024 * 1024 * 1024) # GiB @staticmethod @@ -92,4 +95,5 @@ def specs(): @staticmethod def hostname(): + """ Return the hostname. """ return socket.gethostname() diff --git a/emtools/utils/time.py b/emtools/utils/time.py index e35f3d7..fdbe875 100644 --- a/emtools/utils/time.py +++ b/emtools/utils/time.py @@ -26,7 +26,9 @@ def __init__(self, message="Elapsed:"): self.message = message self.tic() - def tic(self): + def tic(self, msg=None): + if msg: + print(msg) self._dt = datetime.now() def getElapsedTime(self): diff --git a/setup.py b/setup.py index 83337d6..486211c 100644 --- a/setup.py +++ b/setup.py @@ -79,5 +79,9 @@ 'emt-beamshifts = emtools.scripts.emt_beamshifts:main', 'emt-angdist = emtools.scripts.emt_angdist:main' ], + }, + scripts= [ + 'emtools/scripts/emt-scipion-otf.py' + ] )