diff --git a/tempo_embeddings/io/geocoder.py b/tempo_embeddings/io/geocoder.py
index 312d51f..2841f07 100644
--- a/tempo_embeddings/io/geocoder.py
+++ b/tempo_embeddings/io/geocoder.py
@@ -24,10 +24,10 @@ def __init__(
         """
         self.db_path = db_path
         self.geolocator = Nominatim(user_agent=user_agent, timeout=timeout)
-        self.init_db()
+        self._init_db()
         self.last_request_time = 0
 
-    def init_db(self) -> None:
+    def _init_db(self) -> None:
         """
         Initializes the SQLite database.
         """
@@ -45,7 +45,7 @@ def init_db(self) -> None:
         conn.commit()
         conn.close()
 
-    def get_cached_location(self, place_name: str) -> Optional[Tuple[float, float]]:
+    def _get_cached_location(self, place_name: str) -> Optional[Tuple[float, float]]:
         """
         Retrieves a cached location from the SQLite database.
 
@@ -65,7 +65,7 @@ def get_cached_location(self, place_name: str) -> Optional[Tuple[float, float]]:
         conn.close()
         return result
 
-    def cache_location(
+    def _cache_location(
         self, place_name: str, latitude: Optional[float], longitude: Optional[float]
     ) -> None:
         """
@@ -97,26 +97,26 @@ def geocode_place(self, place_name: str) -> Optional[Tuple[float, float]]:
         Returns:
             Optional[Tuple[float, float]]: A tuple containing the latitude and longitude, or None if not found.
         """
-        cached_location = self.get_cached_location(place_name)
+        cached_location = self._get_cached_location(place_name)
         if cached_location:
-            return cached_location
-
-        current_time = time.time()
-        elapsed_time = current_time - self.last_request_time
-        if elapsed_time < 1:
-            time.sleep(1)  # Respect the rate limit of 1 request per second
-        self.last_request_time = time.time()
-
-        try:
-            location = self.geolocator.geocode(place_name)
-        except GeocoderServiceError as e:
-            logging.error(f"Geocoding request for '{place_name}' timed out: {e}")
-            lat, long = None, None
+            lat, long = cached_location
         else:
-            if location:
-                lat, long = location.latitude, location.longitude
-            else:
+            current_time = time.time()
+            elapsed_time = current_time - self.last_request_time
+            if elapsed_time < 1:
+                time.sleep(1)  # Respect the rate limit of 1 request per second
+            self.last_request_time = time.time()
+
+            try:
+                location = self.geolocator.geocode(place_name)
+            except GeocoderServiceError as e:
+                logging.error(f"Geocoding request for '{place_name}' timed out: {e}")
                 lat, long = None, None
-        self.cache_location(place_name, lat, long)
+            else:
+                if location:
+                    lat, long = location.latitude, location.longitude
+                else:
+                    lat, long = None, None
+            self._cache_location(place_name, lat, long)
 
         return lat, long
diff --git a/tests/test_geocoder.py b/tests/test_geocoder.py
new file mode 100644
index 0000000..dbb66ee
--- /dev/null
+++ b/tests/test_geocoder.py
@@ -0,0 +1,66 @@
+import os
+import time
+from unittest.mock import MagicMock
+
+import pytest
+
+from tempo_embeddings.io.geocoder import Geocoder
+
+
+@pytest.fixture
+def location():
+    """Latitude/longitude pair returned by the mocked remote service."""
+    return 12.34, 56.78
+
+
+@pytest.fixture
+def geocoder(tmp_path):
+    """Geocoder backed by a throw-away SQLite cache in pytest's tmp_path."""
+    test_db_path = tmp_path / "test_geocode_cache.db"
+    yield Geocoder(db_path=str(test_db_path))
+
+
+@pytest.fixture
+def mock_geocoder(mocker, location):
+    """Patch Nominatim.geocode so the tests do not call the real service."""
+    mock_geocoder = mocker.patch("tempo_embeddings.io.geocoder.Nominatim.geocode")
+    mock_geocoder.return_value = MagicMock(latitude=location[0], longitude=location[1])
+    return mock_geocoder
+
+
+class TestGeocoder:
+    def test_init(self, geocoder):
+        assert os.path.exists(geocoder.db_path)
+
+    def test_cache_location_and_get_cached_location(self, geocoder, location):
+        place_name = "Test Place"
+
+        geocoder._cache_location(place_name, location[0], location[1])
+        cached_location = geocoder._get_cached_location(place_name)
+        assert cached_location == location
+
+    def test_geocode_place(self, mock_geocoder, geocoder, location):
+        place_name = "Test Place"
+
+        assert geocoder.geocode_place(place_name) == location
+
+        # The second lookup must be answered from the cache, so the remote
+        # service should have been called exactly once.
+        geocoder.geocode_place(place_name)
+        mock_geocoder.assert_called_once_with(place_name)
+
+    def test_geocode_place_with_rate_limit(self, mock_geocoder, geocoder, location):
+        place_name = "Test Place"
+
+        # Pretend a request has just been sent so the next one must wait.
+        geocoder.last_request_time = time.time()
+        start_time = time.time()
+        assert geocoder.geocode_place(place_name) == location
+
+        assert time.time() - start_time >= 1, "Should have delayed the remote call"
+        assert (
+            geocoder.last_request_time >= start_time + 1
+        ), "Should have updated last_request_time"
+
+        # Should have called the remote service exactly once.
+        mock_geocoder.assert_called_once_with(place_name)