Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeEfstathiadis committed Oct 16, 2023
1 parent dd129fa commit 903d4d4
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions forest/jasmine/tests/test_traj2stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from forest.jasmine.data2mobmat import great_circle_dist
from forest.jasmine.traj2stats import (Frequency, transform_point_to_circle,
gps_summaries)
Hyperparameters, gps_summaries)


@pytest.fixture()
Expand Down Expand Up @@ -273,16 +273,16 @@ def test_gps_summaries_shape(
return_value=sample_nearby_locations,
)
mocker.patch("forest.jasmine.traj2stats.locate_home", return_value=coords1)

parameters = Hyperparameters()
parameters.save_osm_log = True

summary, _ = gps_summaries(
traj=sample_trajectory,
tz_str="Europe/London",
frequency=Frequency.HOURLY,
parameters=parameters,
places_of_interest=["pub", "fast_food"],
save_osm_log=True,
threshold=None,
split_day_night=False,
person_point_radius=2,
place_point_radius=7.5,
)
assert summary.shape == (24, 21)

Expand All @@ -296,16 +296,16 @@ def test_gps_summaries_places_of_interest(
return_value=sample_nearby_locations,
)
mocker.patch("forest.jasmine.traj2stats.locate_home", return_value=coords1)

parameters = Hyperparameters()
parameters.save_osm_log = True

summary, _ = gps_summaries(
traj=sample_trajectory,
tz_str="Europe/London",
frequency=Frequency.HOURLY,
parameters=parameters,
places_of_interest=["pub", "fast_food"],
save_osm_log=True,
threshold=None,
split_day_night=False,
person_point_radius=2,
place_point_radius=7.5,
)
time_in_places_of_interest = (
summary["pub"] + summary["fast_food"] + summary["other"]
Expand All @@ -324,16 +324,16 @@ def test_gps_summaries_obs_day_night(
return_value=sample_nearby_locations,
)
mocker.patch("forest.jasmine.traj2stats.locate_home", return_value=coords1)

parameters = Hyperparameters()
parameters.save_osm_log = True

summary, _ = gps_summaries(
traj=sample_trajectory,
tz_str="Europe/London",
frequency=Frequency.DAILY,
parameters=parameters,
places_of_interest=["pub", "fast_food"],
save_osm_log=True,
threshold=None,
split_day_night=False,
person_point_radius=2,
place_point_radius=7.5,
)
total_obs = summary["obs_day"] + summary["obs_night"]
assert np.all(round(total_obs, 4) == round(summary["obs_duration"], 4))
Expand All @@ -348,16 +348,17 @@ def test_gps_summaries_datetime_nighttime_shape(
return_value=sample_nearby_locations,
)
mocker.patch("forest.jasmine.traj2stats.locate_home", return_value=coords1)

parameters = Hyperparameters()
parameters.save_osm_log = True
parameters.split_day_night = True

summary, _ = gps_summaries(
traj=sample_trajectory,
tz_str="Europe/London",
frequency=Frequency.DAILY,
parameters=parameters,
places_of_interest=["pub", "fast_food"],
save_osm_log=True,
threshold=None,
split_day_night=True,
person_point_radius=2,
place_point_radius=7.5,
)
assert summary.shape == (2, 46)

Expand All @@ -373,16 +374,16 @@ def test_gps_summaries_log_format(
return_value=sample_nearby_locations,
)
mocker.patch("forest.jasmine.traj2stats.locate_home", return_value=coords1)

parameters = Hyperparameters()
parameters.save_osm_log = True

summary, log = gps_summaries(
traj=sample_trajectory,
tz_str="Europe/London",
frequency=Frequency.DAILY,
parameters=parameters,
places_of_interest=["pub", "fast_food"],
save_osm_log=True,
threshold=None,
split_day_night=False,
person_point_radius=2,
place_point_radius=7.5,
)
dates_stats = (
summary["day"].astype(int).astype(str)
Expand Down

0 comments on commit 903d4d4

Please sign in to comment.