Skip to content

Commit

Permalink
test(loaders): Add test for load_travel_times #125 (#169)
Browse files Browse the repository at this point in the history
* test(loaders): Add test for load_travel_times #125
* Add unittest coverage for loaders.py load_travel_times()
* Issue -> #125

* * Reverted changes to config.yaml
* Split test_load_travel_times() into 2 smaller unittests

* * Refactor repeated code blocks into single helper function. Following DRY.
  • Loading branch information
madhavmk authored Oct 26, 2023
1 parent c0a7e30 commit 69fc8d3
Showing 1 changed file with 133 additions and 4 deletions.
137 changes: 133 additions & 4 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,47 @@
from typing import Any, Dict
from typing import Any, Dict, List

import pandas as pd
import pytest
from pandas import DataFrame, read_csv
from pandas import DataFrame, concat, read_csv
from pandas.api.types import is_float_dtype

from gnatss.configs.main import Configuration
from gnatss.constants import GPS_COV, GPS_GEOCENTRIC, GPS_TIME, SP_DEPTH, SP_SOUND_SPEED
from gnatss.loaders import load_configuration, load_gps_solutions, load_sound_speed
from gnatss.constants import (
GPS_COV,
GPS_GEOCENTRIC,
GPS_TIME,
SP_DEPTH,
SP_SOUND_SPEED,
TT_DATE,
TT_TIME,
)
from gnatss.loaders import (
load_configuration,
load_gps_solutions,
load_sound_speed,
load_travel_times,
)
from gnatss.main import gather_files
from tests import TEST_DATA_FOLDER

PARSED_FILE = "parsed"


@pytest.fixture
def all_files_dict() -> Dict[str, Any]:
config = load_configuration(TEST_DATA_FOLDER / "config.yaml")
return gather_files(config)


@pytest.fixture
def all_files_dict_j2k_travel_times() -> Dict[str, Any]:
config = load_configuration(TEST_DATA_FOLDER / "config.yaml")
config.solver.input_files.travel_times.path = (
"./tests/data/2022/NCL1/**/WG_*/pxp_tt_j2k"
)
return gather_files(config)


@pytest.mark.parametrize(
"config_yaml_path",
[None, TEST_DATA_FOLDER / "invalid_config.yaml"],
Expand All @@ -44,6 +68,111 @@ def test_load_sound_speed(all_files_dict):
assert is_float_dtype(svdf[SP_DEPTH]) and is_float_dtype(svdf[SP_SOUND_SPEED])


@pytest.fixture
def transponder_ids() -> List[str]:
config = load_configuration(TEST_DATA_FOLDER / "config.yaml")
transponders = config.solver.transponders
return [t.pxp_id for t in transponders]


def _load_travel_times_pass_testcase_helper(
expected_columns, travel_times, transponder_ids, is_j2k, time_scale
):
loaded_travel_times = load_travel_times(
files=travel_times,
transponder_ids=transponder_ids,
is_j2k=is_j2k,
time_scale=time_scale,
)

# raw_travel_times contains the expected df
raw_travel_times = concat(
[
read_csv(i, delim_whitespace=True, header=None)
for i in travel_times
if PARSED_FILE not in i
]
).reset_index(drop=True)

column_num_diff = len(expected_columns) - len(raw_travel_times.columns)
if column_num_diff < 0:
raw_travel_times = raw_travel_times.iloc[:, :column_num_diff]
raw_travel_times.columns = expected_columns

if not is_j2k:
raw_travel_times = raw_travel_times.drop([TT_DATE], axis=1)

# Assert that df returned from "loaded_travel_times()" matches parameters of expected df
assert isinstance(loaded_travel_times, DataFrame)
assert all(
is_float_dtype(loaded_travel_times[column])
for column in [*transponder_ids, TT_TIME]
)
assert loaded_travel_times.shape == raw_travel_times.shape
assert set(loaded_travel_times.columns.values.tolist()) == set(
raw_travel_times.columns.values.tolist()
)

# Verify microsecond to second conversion for each transponder_id column
assert loaded_travel_times[transponder_ids].equals(
raw_travel_times[transponder_ids].apply(lambda x: x * 1e-6)
)


@pytest.mark.parametrize(
"is_j2k, time_scale",
[(True, "tt"), (False, "tt")],
)
def test_load_j2k_travel_times(
transponder_ids, all_files_dict_j2k_travel_times, is_j2k, time_scale
):
if not is_j2k:
# load_travel_times() should raise Exception
# if called with is_j2k=False on j2k type travel time files
with pytest.raises(AttributeError):
_ = load_travel_times(
files=all_files_dict_j2k_travel_times["travel_times"],
transponder_ids=transponder_ids,
is_j2k=is_j2k,
time_scale=time_scale,
)
else:
expected_columns = [TT_TIME, *transponder_ids]
_load_travel_times_pass_testcase_helper(
expected_columns,
all_files_dict_j2k_travel_times["travel_times"],
transponder_ids,
is_j2k,
time_scale,
)


@pytest.mark.parametrize(
"is_j2k, time_scale",
[(True, "tt"), (False, "tt")],
)
def test_load_non_j2k_travel_times(transponder_ids, all_files_dict, is_j2k, time_scale):
if is_j2k:
# load_travel_times() should raise Exception
# if called with is_j2k=True on non-j2k type travel time files
with pytest.raises(TypeError):
_ = load_travel_times(
files=all_files_dict["travel_times"],
transponder_ids=transponder_ids,
is_j2k=is_j2k,
time_scale=time_scale,
)
else:
expected_columns = [TT_DATE, TT_TIME, *transponder_ids]
_load_travel_times_pass_testcase_helper(
expected_columns,
all_files_dict["travel_times"],
transponder_ids,
is_j2k,
time_scale,
)


@pytest.mark.parametrize(
"time_round",
[3, 6],
Expand Down

0 comments on commit 69fc8d3

Please sign in to comment.