diff --git a/clinica/pipelines/statistics_surface/_inputs.py b/clinica/pipelines/statistics_surface/_inputs.py index d889a2b76..755d07b31 100644 --- a/clinica/pipelines/statistics_surface/_inputs.py +++ b/clinica/pipelines/statistics_surface/_inputs.py @@ -26,20 +26,16 @@ def _read_and_check_tsv_file(tsv_file: PathLike) -> pd.DataFrame: tsv_data : pd.DataFrame DataFrame obtained from the file. """ - if not Path(tsv_file).exists(): - raise FileNotFoundError(f"File {tsv_file} does not exist.") - tsv_data = pd.read_csv(tsv_file, sep="\t") - if len(tsv_data.columns) < 2: - raise ValueError(f"The TSV data in {tsv_file} should have at least 2 columns.") - if tsv_data.columns[0] != TSV_FIRST_COLUMN: - raise ValueError( - f"The first column in {tsv_file} should always be {TSV_FIRST_COLUMN}." + try: + return pd.read_csv(tsv_file, sep="\t").set_index( + [TSV_FIRST_COLUMN, TSV_SECOND_COLUMN] ) - if tsv_data.columns[1] != TSV_SECOND_COLUMN: + except FileNotFoundError: + raise FileNotFoundError(f"File {tsv_file} does not exist.") + except KeyError: raise ValueError( - f"The second column in {tsv_file} should always be {TSV_SECOND_COLUMN}." + f"The TSV data should have at least two columns: {TSV_FIRST_COLUMN} and {TSV_SECOND_COLUMN}" ) - return tsv_data def _get_t1_freesurfer_custom_file_template(base_dir: PathLike) -> str: diff --git a/test/unittests/pipelines/statistics_surface/test_inputs.py b/test/unittests/pipelines/statistics_surface/test_inputs.py index af806aa86..02b204c85 100644 --- a/test/unittests/pipelines/statistics_surface/test_inputs.py +++ b/test/unittests/pipelines/statistics_surface/test_inputs.py @@ -7,30 +7,31 @@ CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -def test_read_and_check_tsv_file(tmpdir): +def test_read_and_check_tsv_file_filenotfound_error(tmpdir): from clinica.pipelines.statistics_surface._inputs import _read_and_check_tsv_file with pytest.raises(FileNotFoundError, match="File foo.tsv does not exist"): _read_and_check_tsv_file(Path("foo.tsv")) - df = pd.DataFrame(columns=["foo"]) - df.to_csv(tmpdir / "foo.tsv", sep="\t", index=False) - with pytest.raises( - ValueError, match=r"The TSV data in .*foo.tsv should have at least 2 columns." - ): - _read_and_check_tsv_file(tmpdir / "foo.tsv") - df = pd.DataFrame(columns=["foo", "bar"]) + + +@pytest.mark.parametrize( + "columns", [["foo"], ["foo", "bar"], ["participant_id", "bar"]] +) +def test_read_and_check_tsv_file_data_errors(tmpdir, columns): + from clinica.pipelines.statistics_surface._inputs import _read_and_check_tsv_file + + df = pd.DataFrame(columns=columns) df.to_csv(tmpdir / "foo.tsv", sep="\t", index=False) with pytest.raises( ValueError, - match=r"The first column in .*foo.tsv should always be participant_id.", - ): - _read_and_check_tsv_file(tmpdir / "foo.tsv") - df = pd.DataFrame(columns=["participant_id", "bar"]) - df.to_csv(tmpdir / "foo.tsv", sep="\t", index=False) - with pytest.raises( - ValueError, match=r"The second column in .*foo.tsv should always be session_id." + match=r"The TSV data should have at least two columns: participant_id and session_id", ): _read_and_check_tsv_file(tmpdir / "foo.tsv") + + +def test_read_and_check_tsv_file(): + from clinica.pipelines.statistics_surface._inputs import _read_and_check_tsv_file + df = _read_and_check_tsv_file(Path(CURRENT_DIR) / "data/subjects.tsv") assert len(df) == 7 - assert set(df.columns) == {"participant_id", "session_id", "group", "age", "sex"} + assert set(df.columns) == {"group", "age", "sex"}