From cdee01d5da0b42549e5b29365a30b1dd8ff36241 Mon Sep 17 00:00:00 2001 From: Madhav Kashyap Date: Sun, 15 Oct 2023 18:58:15 -0700 Subject: [PATCH 1/3] test(loaders): Add test for load_travel_times #125 * Add unittest coverage for loaders.py load_travel_times() * Issue -> https://github.com/uw-ssec/offshore-geodesy/issues/125 --- tests/data/config.yaml | 2 +- tests/test_loaders.py | 88 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/tests/data/config.yaml b/tests/data/config.yaml index 960ba3c..a97e8fe 100644 --- a/tests/data/config.yaml +++ b/tests/data/config.yaml @@ -35,7 +35,7 @@ solver: sound_speed: path: ./tests/data/2022/NCL1/ctd/CTD_NCL1_Ch_Mi travel_times: - path: ./tests/data/2022/NCL1/**/WG_*/pxp_tt + path: ./tests/data/2022/NCL1/**/WG_*/pxp_tt* gps_solution: path: ./tests/data/2022/NCL1/**/posfilter/POS_FREED_TRANS_TWTT # Make deletions file optional diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 8b28c4c..9a4d3ba 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -1,12 +1,13 @@ -from typing import Any, Dict +from typing import Any, Dict, List import pytest -from numpy import float64 -from pandas import DataFrame +from numpy import float64, isclose +from pandas import DataFrame, concat, read_csv +from pandas.api.types import is_float_dtype from gnatss.configs.main import Configuration -from gnatss.constants import SP_DEPTH, SP_SOUND_SPEED -from gnatss.loaders import load_configuration, load_sound_speed +from gnatss.constants import SP_DEPTH, SP_SOUND_SPEED, TT_DATE, TT_TIME +from gnatss.loaders import load_configuration, load_sound_speed, load_travel_times from gnatss.main import gather_files from tests import TEST_DATA_FOLDER @@ -41,3 +42,80 @@ def test_load_sound_speed(all_files_dict): assert isinstance(svdf, DataFrame) assert {SP_DEPTH, SP_SOUND_SPEED} == set(svdf.columns.values.tolist()) assert svdf.dtypes[SP_DEPTH] == float64 and svdf.dtypes[SP_SOUND_SPEED] == float64 + + +@pytest.fixture +def transponder_ids() -> List[str]: + config = load_configuration(TEST_DATA_FOLDER / "config.yaml") + transponders = config.solver.transponders + return [t.pxp_id for t in transponders] + + +@pytest.mark.parametrize( + "is_j2k, time_scale", + [ + (True, "tt"), + (False, "tt"), + ], +) +def test_load_travel_times(all_files_dict, transponder_ids, is_j2k, time_scale): + PARSED_FILE = "parsed" + + if is_j2k: + expected_columns = [TT_TIME, *transponder_ids] + loaded_travel_times = load_travel_times( + files=[file for file in all_files_dict["travel_times"] if "j2k" in file], + transponder_ids=transponder_ids, + is_j2k=is_j2k, + time_scale=time_scale, + ) + raw_travel_times = concat( + [ + read_csv(i, delim_whitespace=True, header=None) + for i in all_files_dict["travel_times"] + if ((PARSED_FILE not in i) and ("j2k" in i)) + ] + ).reset_index(drop=True) + column_num_diff = len(expected_columns) - len(raw_travel_times.columns) + if column_num_diff < 0: + raw_travel_times = raw_travel_times.iloc[:, :column_num_diff] + raw_travel_times.columns = expected_columns + + else: + expected_columns = [TT_DATE, TT_TIME, *transponder_ids] + loaded_travel_times = load_travel_times( + files=[ + file for file in all_files_dict["travel_times"] if "j2k" not in file + ], + transponder_ids=transponder_ids, + is_j2k=is_j2k, + time_scale=time_scale, + ) + raw_travel_times = concat( + [ + read_csv(i, delim_whitespace=True, header=None) + for i in all_files_dict["travel_times"] + if ((PARSED_FILE not in i) and ("j2k" not in i)) + ] + ).reset_index(drop=True) + column_num_diff = len(expected_columns) - len(raw_travel_times.columns) + if column_num_diff < 0: + raw_travel_times = raw_travel_times.iloc[:, :column_num_diff] + raw_travel_times.columns = expected_columns + raw_travel_times = raw_travel_times.drop([TT_DATE], axis=1) + + assert isinstance(loaded_travel_times, DataFrame) + assert all( + is_float_dtype(loaded_travel_times[column]) + for column in [*transponder_ids, TT_TIME] + ) + assert loaded_travel_times.shape == raw_travel_times.shape + assert set(loaded_travel_times.columns.values.tolist()) == set( + raw_travel_times.columns.values.tolist() + ) + + # Verify microseconds to seconds conversion for delay times + for transponder_id in transponder_ids: + assert isclose( + raw_travel_times[transponder_id] * 1e-6, loaded_travel_times[transponder_id] + ).all() From 46278de8622f38f4855de1ec3ea7fa3a7fac3bf3 Mon Sep 17 00:00:00 2001 From: Madhav Kashyap Date: Sat, 21 Oct 2023 22:19:02 -0700 Subject: [PATCH 2/3] * Reverted changes to config.yaml * Split test_load_travel_times() into 2 smaller unittests --- tests/data/config.yaml | 2 +- tests/test_loaders.py | 113 ++++++++++++++++++++++++++++++----------- 2 files changed, 84 insertions(+), 31 deletions(-) diff --git a/tests/data/config.yaml b/tests/data/config.yaml index a97e8fe..960ba3c 100644 --- a/tests/data/config.yaml +++ b/tests/data/config.yaml @@ -35,7 +35,7 @@ solver: sound_speed: path: ./tests/data/2022/NCL1/ctd/CTD_NCL1_Ch_Mi travel_times: - path: ./tests/data/2022/NCL1/**/WG_*/pxp_tt* + path: ./tests/data/2022/NCL1/**/WG_*/pxp_tt gps_solution: path: ./tests/data/2022/NCL1/**/posfilter/POS_FREED_TRANS_TWTT # Make deletions file optional diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 9a4d3ba..5a6f544 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List import pytest -from numpy import float64, isclose +from numpy import float64 from pandas import DataFrame, concat, read_csv from pandas.api.types import is_float_dtype @@ -11,6 +11,8 @@ from gnatss.main import gather_files from tests import TEST_DATA_FOLDER +PARSED_FILE = "parsed" + @pytest.fixture def all_files_dict() -> Dict[str, Any]: @@ -18,6 +20,15 @@ def all_files_dict() -> Dict[str, Any]: return gather_files(config) +@pytest.fixture +def all_files_dict_j2k_travel_times() -> Dict[str, Any]: + config = load_configuration(TEST_DATA_FOLDER / "config.yaml") + config.solver.input_files.travel_times.path = ( + "./tests/data/2022/NCL1/**/WG_*/pxp_tt_j2k" + ) + return gather_files(config) + + @pytest.mark.parametrize( "config_yaml_path", [(None), (TEST_DATA_FOLDER / "invalid_config.yaml")], @@ -53,49 +64,81 @@ def transponder_ids() -> List[str]: @pytest.mark.parametrize( "is_j2k, time_scale", - [ - (True, "tt"), - (False, "tt"), - ], + [(True, "tt"), (False, "tt")], ) -def test_load_travel_times(all_files_dict, transponder_ids, is_j2k, time_scale): - PARSED_FILE = "parsed" - +def test_load_j2k_travel_times( + transponder_ids, all_files_dict_j2k_travel_times, is_j2k, time_scale +): if is_j2k: - expected_columns = [TT_TIME, *transponder_ids] loaded_travel_times = load_travel_times( - files=[file for file in all_files_dict["travel_times"] if "j2k" in file], + files=all_files_dict_j2k_travel_times["travel_times"], transponder_ids=transponder_ids, is_j2k=is_j2k, time_scale=time_scale, ) + + # raw_travel_times contains the expected df raw_travel_times = concat( [ read_csv(i, delim_whitespace=True, header=None) - for i in all_files_dict["travel_times"] - if ((PARSED_FILE not in i) and ("j2k" in i)) + for i in all_files_dict_j2k_travel_times["travel_times"] + if PARSED_FILE not in i ] ).reset_index(drop=True) + + expected_columns = [TT_TIME, *transponder_ids] column_num_diff = len(expected_columns) - len(raw_travel_times.columns) if column_num_diff < 0: raw_travel_times = raw_travel_times.iloc[:, :column_num_diff] raw_travel_times.columns = expected_columns + # Assert that df returned from "loaded_travel_times()" matches parameters of expected df + assert isinstance(loaded_travel_times, DataFrame) + assert all( + is_float_dtype(loaded_travel_times[column]) + for column in [*transponder_ids, TT_TIME] + ) + assert loaded_travel_times.shape == raw_travel_times.shape + assert set(loaded_travel_times.columns.values.tolist()) == set( + raw_travel_times.columns.values.tolist() + ) + + # Verify microsecond to second conversion for each transponder_id column + assert loaded_travel_times[transponder_ids].equals( + raw_travel_times[transponder_ids].apply(lambda x: x * 1e-6) + ) else: + # load_travel_times() should raise Exception + # if called with is_j2k=False on j2k type travel time files + with pytest.raises(AttributeError): + _ = load_travel_times( + files=all_files_dict_j2k_travel_times["travel_times"], + transponder_ids=transponder_ids, + is_j2k=is_j2k, + time_scale=time_scale, + ) + + +@pytest.mark.parametrize( + "is_j2k, time_scale", + [(True, "tt"), (False, "tt")], +) +def test_load_non_j2k_travel_times(transponder_ids, all_files_dict, is_j2k, time_scale): + if not is_j2k: expected_columns = [TT_DATE, TT_TIME, *transponder_ids] loaded_travel_times = load_travel_times( - files=[ - file for file in all_files_dict["travel_times"] if "j2k" not in file - ], + files=all_files_dict["travel_times"], transponder_ids=transponder_ids, is_j2k=is_j2k, time_scale=time_scale, ) + + # raw_travel_times contains the expected df raw_travel_times = concat( [ read_csv(i, delim_whitespace=True, header=None) for i in all_files_dict["travel_times"] - if ((PARSED_FILE not in i) and ("j2k" not in i)) + if PARSED_FILE not in i ] ).reset_index(drop=True) column_num_diff = len(expected_columns) - len(raw_travel_times.columns) @@ -104,18 +147,28 @@ def test_load_travel_times(all_files_dict, transponder_ids, is_j2k, time_scale): raw_travel_times.columns = expected_columns raw_travel_times = raw_travel_times.drop([TT_DATE], axis=1) - assert isinstance(loaded_travel_times, DataFrame) - assert all( - is_float_dtype(loaded_travel_times[column]) - for column in [*transponder_ids, TT_TIME] - ) - assert loaded_travel_times.shape == raw_travel_times.shape - assert set(loaded_travel_times.columns.values.tolist()) == set( - raw_travel_times.columns.values.tolist() - ) + # Assert that df returned from "loaded_travel_times()" matches parameters of expected df + assert isinstance(loaded_travel_times, DataFrame) + assert all( + is_float_dtype(loaded_travel_times[column]) + for column in [*transponder_ids, TT_TIME] + ) + assert loaded_travel_times.shape == raw_travel_times.shape + assert set(loaded_travel_times.columns.values.tolist()) == set( + raw_travel_times.columns.values.tolist() + ) - # Verify microseconds to seconds conversion for delay times - for transponder_id in transponder_ids: - assert isclose( - raw_travel_times[transponder_id] * 1e-6, loaded_travel_times[transponder_id] - ).all() + # Verify microsecond to second conversion for each transponder_id column + assert loaded_travel_times[transponder_ids].equals( + raw_travel_times[transponder_ids].apply(lambda x: x * 1e-6) + ) + else: + # load_travel_times() should raise Exception + # if called with is_j2k=True on non-j2k type travel time files + with pytest.raises(TypeError): + _ = load_travel_times( + files=all_files_dict["travel_times"], + transponder_ids=transponder_ids, + is_j2k=is_j2k, + time_scale=time_scale, + ) From dc0c5ab23a36a3784f61f3aeca1596651845d48a Mon Sep 17 00:00:00 2001 From: Madhav Kashyap Date: Tue, 24 Oct 2023 17:07:55 -0700 Subject: [PATCH 3/3] * Refactor repeated code blocks into single helper function. Following DRY. --- tests/test_loaders.py | 134 +++++++++++++++++++----------------------- 1 file changed, 60 insertions(+), 74 deletions(-) diff --git a/tests/test_loaders.py b/tests/test_loaders.py index c016424..cf215b3 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -75,6 +75,50 @@ def transponder_ids() -> List[str]: return [t.pxp_id for t in transponders] +def _load_travel_times_pass_testcase_helper( + expected_columns, travel_times, transponder_ids, is_j2k, time_scale +): + loaded_travel_times = load_travel_times( + files=travel_times, + transponder_ids=transponder_ids, + is_j2k=is_j2k, + time_scale=time_scale, + ) + + # raw_travel_times contains the expected df + raw_travel_times = concat( + [ + read_csv(i, delim_whitespace=True, header=None) + for i in travel_times + if PARSED_FILE not in i + ] + ).reset_index(drop=True) + + column_num_diff = len(expected_columns) - len(raw_travel_times.columns) + if column_num_diff < 0: + raw_travel_times = raw_travel_times.iloc[:, :column_num_diff] + raw_travel_times.columns = expected_columns + + if not is_j2k: + raw_travel_times = raw_travel_times.drop([TT_DATE], axis=1) + + # Assert that df returned from "loaded_travel_times()" matches parameters of expected df + assert isinstance(loaded_travel_times, DataFrame) + assert all( + is_float_dtype(loaded_travel_times[column]) + for column in [*transponder_ids, TT_TIME] + ) + assert loaded_travel_times.shape == raw_travel_times.shape + assert set(loaded_travel_times.columns.values.tolist()) == set( + raw_travel_times.columns.values.tolist() + ) + + # Verify microsecond to second conversion for each transponder_id column + assert loaded_travel_times[transponder_ids].equals( + raw_travel_times[transponder_ids].apply(lambda x: x * 1e-6) + ) + + @pytest.mark.parametrize( "is_j2k, time_scale", [(True, "tt"), (False, "tt")], @@ -93,42 +137,13 @@ def test_load_j2k_travel_times( time_scale=time_scale, ) else: - loaded_travel_times = load_travel_times( - files=all_files_dict_j2k_travel_times["travel_times"], - transponder_ids=transponder_ids, - is_j2k=is_j2k, - time_scale=time_scale, - ) - - # raw_travel_times contains the expected df - raw_travel_times = concat( - [ - read_csv(i, delim_whitespace=True, header=None) - for i in all_files_dict_j2k_travel_times["travel_times"] - if PARSED_FILE not in i - ] - ).reset_index(drop=True) - expected_columns = [TT_TIME, *transponder_ids] - column_num_diff = len(expected_columns) - len(raw_travel_times.columns) - if column_num_diff < 0: - raw_travel_times = raw_travel_times.iloc[:, :column_num_diff] - raw_travel_times.columns = expected_columns - - # Assert that df returned from "loaded_travel_times()" matches parameters of expected df - assert isinstance(loaded_travel_times, DataFrame) - assert all( - is_float_dtype(loaded_travel_times[column]) - for column in [*transponder_ids, TT_TIME] - ) - assert loaded_travel_times.shape == raw_travel_times.shape - assert set(loaded_travel_times.columns.values.tolist()) == set( - raw_travel_times.columns.values.tolist() - ) - - # Verify microsecond to second conversion for each transponder_id column - assert loaded_travel_times[transponder_ids].equals( - raw_travel_times[transponder_ids].apply(lambda x: x * 1e-6) + _load_travel_times_pass_testcase_helper( + expected_columns, + all_files_dict_j2k_travel_times["travel_times"], + transponder_ids, + is_j2k, + time_scale, ) @@ -137,45 +152,7 @@ def test_load_j2k_travel_times( [(True, "tt"), (False, "tt")], ) def test_load_non_j2k_travel_times(transponder_ids, all_files_dict, is_j2k, time_scale): - if not is_j2k: - expected_columns = [TT_DATE, TT_TIME, *transponder_ids] - loaded_travel_times = load_travel_times( - files=all_files_dict["travel_times"], - transponder_ids=transponder_ids, - is_j2k=is_j2k, - time_scale=time_scale, - ) - - # raw_travel_times contains the expected df - raw_travel_times = concat( - [ - read_csv(i, delim_whitespace=True, header=None) - for i in all_files_dict["travel_times"] - if PARSED_FILE not in i - ] - ).reset_index(drop=True) - column_num_diff = len(expected_columns) - len(raw_travel_times.columns) - if column_num_diff < 0: - raw_travel_times = raw_travel_times.iloc[:, :column_num_diff] - raw_travel_times.columns = expected_columns - raw_travel_times = raw_travel_times.drop([TT_DATE], axis=1) - - # Assert that df returned from "loaded_travel_times()" matches parameters of expected df - assert isinstance(loaded_travel_times, DataFrame) - assert all( - is_float_dtype(loaded_travel_times[column]) - for column in [*transponder_ids, TT_TIME] - ) - assert loaded_travel_times.shape == raw_travel_times.shape - assert set(loaded_travel_times.columns.values.tolist()) == set( - raw_travel_times.columns.values.tolist() - ) - - # Verify microsecond to second conversion for each transponder_id column - assert loaded_travel_times[transponder_ids].equals( - raw_travel_times[transponder_ids].apply(lambda x: x * 1e-6) - ) - else: + if is_j2k: # load_travel_times() should raise Exception # if called with is_j2k=True on non-j2k type travel time files with pytest.raises(TypeError): @@ -185,6 +162,15 @@ def test_load_non_j2k_travel_times(transponder_ids, all_files_dict, is_j2k, time is_j2k=is_j2k, time_scale=time_scale, ) + else: + expected_columns = [TT_DATE, TT_TIME, *transponder_ids] + _load_travel_times_pass_testcase_helper( + expected_columns, + all_files_dict["travel_times"], + transponder_ids, + is_j2k, + time_scale, + ) @pytest.mark.parametrize(