From 69fc8d33b6ac514d0f35e6e90379dc420022aed0 Mon Sep 17 00:00:00 2001
From: Madhav Mahesh Kashyap <29497860+madhavmk@users.noreply.github.com>
Date: Thu, 26 Oct 2023 09:52:40 -0700
Subject: [PATCH] test(loaders): Add test for load_travel_times #125 (#169)

* test(loaders): Add test for load_travel_times #125

* Add unittest coverage for loaders.py load_travel_times()
* Issue -> https://github.com/uw-ssec/offshore-geodesy/issues/125

* Reverted changes to config.yaml
* Split test_load_travel_times() into two smaller unit tests

* Refactor repeated code blocks into a single helper function, following DRY
---
 tests/test_loaders.py | 137 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 133 insertions(+), 4 deletions(-)

diff --git a/tests/test_loaders.py b/tests/test_loaders.py
index 5465e03..cf215b3 100644
--- a/tests/test_loaders.py
+++ b/tests/test_loaders.py
@@ -1,16 +1,31 @@
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import pandas as pd
 import pytest
-from pandas import DataFrame, read_csv
+from pandas import DataFrame, concat, read_csv
 from pandas.api.types import is_float_dtype
 
 from gnatss.configs.main import Configuration
-from gnatss.constants import GPS_COV, GPS_GEOCENTRIC, GPS_TIME, SP_DEPTH, SP_SOUND_SPEED
-from gnatss.loaders import load_configuration, load_gps_solutions, load_sound_speed
+from gnatss.constants import (
+    GPS_COV,
+    GPS_GEOCENTRIC,
+    GPS_TIME,
+    SP_DEPTH,
+    SP_SOUND_SPEED,
+    TT_DATE,
+    TT_TIME,
+)
+from gnatss.loaders import (
+    load_configuration,
+    load_gps_solutions,
+    load_sound_speed,
+    load_travel_times,
+)
 from gnatss.main import gather_files
 from tests import TEST_DATA_FOLDER
 
+PARSED_FILE = "parsed"
+
 
 @pytest.fixture
 def all_files_dict() -> Dict[str, Any]:
@@ -18,6 +33,15 @@ def all_files_dict() -> Dict[str, Any]:
     return gather_files(config)
 
 
+@pytest.fixture
+def all_files_dict_j2k_travel_times() -> Dict[str, Any]:
+    config = load_configuration(TEST_DATA_FOLDER / "config.yaml")
+    config.solver.input_files.travel_times.path = (
+        "./tests/data/2022/NCL1/**/WG_*/pxp_tt_j2k"
+    )
+    return gather_files(config)
+
+
 @pytest.mark.parametrize(
     "config_yaml_path",
     [None, TEST_DATA_FOLDER / "invalid_config.yaml"],
@@ -44,6 +68,111 @@ def test_load_sound_speed(all_files_dict):
     assert is_float_dtype(svdf[SP_DEPTH]) and is_float_dtype(svdf[SP_SOUND_SPEED])
 
 
+@pytest.fixture
+def transponder_ids() -> List[str]:
+    config = load_configuration(TEST_DATA_FOLDER / "config.yaml")
+    transponders = config.solver.transponders
+    return [t.pxp_id for t in transponders]
+
+
+def _load_travel_times_pass_testcase_helper(
+    expected_columns, travel_times, transponder_ids, is_j2k, time_scale
+):
+    loaded_travel_times = load_travel_times(
+        files=travel_times,
+        transponder_ids=transponder_ids,
+        is_j2k=is_j2k,
+        time_scale=time_scale,
+    )
+
+    # raw_travel_times contains the expected df
+    raw_travel_times = concat(
+        [
+            read_csv(i, delim_whitespace=True, header=None)
+            for i in travel_times
+            if PARSED_FILE not in i
+        ]
+    ).reset_index(drop=True)
+
+    column_num_diff = len(expected_columns) - len(raw_travel_times.columns)
+    if column_num_diff < 0:
+        raw_travel_times = raw_travel_times.iloc[:, :column_num_diff]
+    raw_travel_times.columns = expected_columns
+
+    if not is_j2k:
+        raw_travel_times = raw_travel_times.drop([TT_DATE], axis=1)
+
+    # Assert that the df returned by load_travel_times() matches the expected df
+    assert isinstance(loaded_travel_times, DataFrame)
+    assert all(
+        is_float_dtype(loaded_travel_times[column])
+        for column in [*transponder_ids, TT_TIME]
+    )
+    assert loaded_travel_times.shape == raw_travel_times.shape
+    assert set(loaded_travel_times.columns.values.tolist()) == set(
+        raw_travel_times.columns.values.tolist()
+    )
+
+    # Verify microsecond to second conversion for each transponder_id column
+    assert loaded_travel_times[transponder_ids].equals(
+        raw_travel_times[transponder_ids].apply(lambda x: x * 1e-6)
+    )
+
+
+@pytest.mark.parametrize(
+    "is_j2k, time_scale",
+    [(True, "tt"), (False, "tt")],
+)
+def test_load_j2k_travel_times(
+    transponder_ids, all_files_dict_j2k_travel_times, is_j2k, time_scale
+):
+    if not is_j2k:
+        # load_travel_times() should raise an exception
+        # if called with is_j2k=False on j2k-type travel time files
+        with pytest.raises(AttributeError):
+            _ = load_travel_times(
+                files=all_files_dict_j2k_travel_times["travel_times"],
+                transponder_ids=transponder_ids,
+                is_j2k=is_j2k,
+                time_scale=time_scale,
+            )
+    else:
+        expected_columns = [TT_TIME, *transponder_ids]
+        _load_travel_times_pass_testcase_helper(
+            expected_columns,
+            all_files_dict_j2k_travel_times["travel_times"],
+            transponder_ids,
+            is_j2k,
+            time_scale,
+        )
+
+
+@pytest.mark.parametrize(
+    "is_j2k, time_scale",
+    [(True, "tt"), (False, "tt")],
+)
+def test_load_non_j2k_travel_times(transponder_ids, all_files_dict, is_j2k, time_scale):
+    if is_j2k:
+        # load_travel_times() should raise an exception
+        # if called with is_j2k=True on non-j2k-type travel time files
+        with pytest.raises(TypeError):
+            _ = load_travel_times(
+                files=all_files_dict["travel_times"],
+                transponder_ids=transponder_ids,
+                is_j2k=is_j2k,
+                time_scale=time_scale,
+            )
+    else:
+        expected_columns = [TT_DATE, TT_TIME, *transponder_ids]
+        _load_travel_times_pass_testcase_helper(
+            expected_columns,
+            all_files_dict["travel_times"],
+            transponder_ids,
+            is_j2k,
+            time_scale,
+        )
+
+
 @pytest.mark.parametrize(
     "time_round",
     [3, 6],
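
For context on the unit conversion these tests verify, here is a minimal, self-contained sketch of the microsecond-to-second check that _load_travel_times_pass_testcase_helper performs; it is not part of the patch, the transponder IDs and travel-time values are hypothetical, and only pandas is assumed:

import pandas as pd

transponder_ids = ["5209", "5210", "5211"]  # hypothetical pxp_id values

# "Raw" two-way travel times as they appear on disk, in microseconds.
raw = pd.DataFrame(
    [
        [2430500.0, 2890100.0, 3120700.0],
        [2431200.0, 2889400.0, 3121900.0],
    ],
    columns=transponder_ids,
)

# A loader that converts microseconds to seconds would return raw * 1e-6,
# which is what the helper compares against with DataFrame.equals().
loaded = raw * 1e-6
assert loaded[transponder_ids].equals(
    raw[transponder_ids].apply(lambda x: x * 1e-6)
)

In this sketch both sides apply the identical floating-point multiplication to the same raw values, so the exact equals() comparison holds without a tolerance-based check.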