Skip to content

Commit

Permalink
Simplify _read_and_check_tsv_file
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasGensollen committed Aug 31, 2022
1 parent 4f9828f commit b225b69
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 27 deletions.
18 changes: 7 additions & 11 deletions clinica/pipelines/statistics_surface/_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,16 @@ def _read_and_check_tsv_file(tsv_file: PathLike) -> pd.DataFrame:
tsv_data : pd.DataFrame
DataFrame obtained from the file.
"""
if not Path(tsv_file).exists():
raise FileNotFoundError(f"File {tsv_file} does not exist.")
tsv_data = pd.read_csv(tsv_file, sep="\t")
if len(tsv_data.columns) < 2:
raise ValueError(f"The TSV data in {tsv_file} should have at least 2 columns.")
if tsv_data.columns[0] != TSV_FIRST_COLUMN:
raise ValueError(
f"The first column in {tsv_file} should always be {TSV_FIRST_COLUMN}."
try:
return pd.read_csv(tsv_file, sep="\t").set_index(
[TSV_FIRST_COLUMN, TSV_SECOND_COLUMN]
)
if tsv_data.columns[1] != TSV_SECOND_COLUMN:
except FileNotFoundError:
raise FileNotFoundError(f"File {tsv_file} does not exist.")
except KeyError:
raise ValueError(
f"The second column in {tsv_file} should always be {TSV_SECOND_COLUMN}."
f"The TSV data should have at least two columns: {TSV_FIRST_COLUMN} and {TSV_SECOND_COLUMN}"
)
return tsv_data


def _get_t1_freesurfer_custom_file_template(base_dir: PathLike) -> str:
Expand Down
33 changes: 17 additions & 16 deletions test/unittests/pipelines/statistics_surface/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,31 @@
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))


def test_read_and_check_tsv_file(tmpdir):
def test_read_and_check_tsv_file_filenotfound_error(tmpdir):
from clinica.pipelines.statistics_surface._inputs import _read_and_check_tsv_file

with pytest.raises(FileNotFoundError, match="File foo.tsv does not exist"):
_read_and_check_tsv_file(Path("foo.tsv"))
df = pd.DataFrame(columns=["foo"])
df.to_csv(tmpdir / "foo.tsv", sep="\t", index=False)
with pytest.raises(
ValueError, match=r"The TSV data in .*foo.tsv should have at least 2 columns."
):
_read_and_check_tsv_file(tmpdir / "foo.tsv")
df = pd.DataFrame(columns=["foo", "bar"])


@pytest.mark.parametrize(
"columns", [["foo"], ["foo", "bar"], ["participant_id", "bar"]]
)
def test_read_and_check_tsv_file_data_errors(tmpdir, columns):
from clinica.pipelines.statistics_surface._inputs import _read_and_check_tsv_file

df = pd.DataFrame(columns=columns)
df.to_csv(tmpdir / "foo.tsv", sep="\t", index=False)
with pytest.raises(
ValueError,
match=r"The first column in .*foo.tsv should always be participant_id.",
):
_read_and_check_tsv_file(tmpdir / "foo.tsv")
df = pd.DataFrame(columns=["participant_id", "bar"])
df.to_csv(tmpdir / "foo.tsv", sep="\t", index=False)
with pytest.raises(
ValueError, match=r"The second column in .*foo.tsv should always be session_id."
match=r"The TSV data should have at least two columns: participant_id and session_id",
):
_read_and_check_tsv_file(tmpdir / "foo.tsv")


def test_read_and_check_tsv_file():
from clinica.pipelines.statistics_surface._inputs import _read_and_check_tsv_file

df = _read_and_check_tsv_file(Path(CURRENT_DIR) / "data/subjects.tsv")
assert len(df) == 7
assert set(df.columns) == {"participant_id", "session_id", "group", "age", "sex"}
assert set(df.columns) == {"group", "age", "sex"}

0 comments on commit b225b69

Please sign in to comment.