Skip to content

Commit

Permalink
fix: get_writer errors if filename is none (MDAnalysis#4043)
Browse files Browse the repository at this point in the history
* fix: get_writer errors if filename is none
  • Loading branch information
jandom authored Mar 5, 2023
1 parent 2cad39e commit 55a5abd
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
3 changes: 2 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ The rules for this file:
------------------------------------------------------------------------------

??/??/?? IAlibay, pgbarletta, mglagolev, hmacdope, manuel.nuno.melo, chrispfae,
ooprathamm, MeetB7, BFedder, v-parmar, MoSchaeffler, jbarnoud
ooprathamm, MeetB7, BFedder, v-parmar, MoSchaeffler, jbarnoud, jandom
* 2.5.0

Fixes
* Fix uninitialized `format` variable issue when calling `selections.get_writer` directly (PR #4043)
* Fix tests should use results.rmsf to avoid DeprecationWarning (Issue #3695)
* Fix EDRReader failing when parsing single-frame EDR files (Issue #3999)
* Fix element parsing from PSF files tests read via Parmed (Issue #4015)
Expand Down
7 changes: 4 additions & 3 deletions package/MDAnalysis/selections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@

from .. import _SELECTION_WRITERS

from . import base
from . import vmd
from . import pymol
from . import gromacs
from . import charmm
from . import jmol


def get_writer(filename, defaultformat):
def get_writer(filename: str, defaultformat: str) -> base.SelectionWriterBase:
"""Return a SelectionWriter for `filename` or a `defaultformat`.
Parameters
Expand All @@ -67,15 +68,15 @@ def get_writer(filename, defaultformat):
Returns
-------
SelectionWriter : `type`
SelectionWriterBase : `type`
the writer *class* for the detected format
Raises
------
:exc:`NotImplementedError`
for any format that is not defined
"""

format = None
if filename:
format = os.path.splitext(filename)[1][1:] # strip initial dot!
format = format or defaultformat # use default if no fmt from fn
Expand Down
9 changes: 9 additions & 0 deletions testsuite/MDAnalysisTests/utils/test_selections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# use StringIO and NamedStream to write to memory instead to temp files
import pytest
from io import StringIO
from unittest.mock import patch

import re

Expand All @@ -34,6 +35,7 @@
from MDAnalysis.tests.datafiles import PSF, DCD

import MDAnalysis
import MDAnalysis.selections as selections
from MDAnalysis.lib.util import NamedStream


Expand Down Expand Up @@ -196,3 +198,10 @@ def _assert_selectionstring(self, namedfile):
err_msg="SPT file has wrong selection name")
assert_array_equal(indices, self.ref_indices,
err_msg="SPT indices were not written correctly")


class TestSelections:
@patch.object(selections, "_SELECTION_WRITERS", {"FOO": "BAR"})
def test_get_writer_valid(self):
writer = selections.get_writer(None, "FOO")
assert writer == "BAR"

0 comments on commit 55a5abd

Please sign in to comment.