Skip to content

Commit

Permalink
Add TestGeocoder, rename internal Geocoder methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
carschno committed Nov 6, 2024
1 parent 0cdb613 commit d46cb2e
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 22 deletions.
44 changes: 22 additions & 22 deletions tempo_embeddings/io/geocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def __init__(
"""
self.db_path = db_path
self.geolocator = Nominatim(user_agent=user_agent, timeout=timeout)
self.init_db()
self._init_db()
self.last_request_time = 0

def init_db(self) -> None:
def _init_db(self) -> None:
"""
Initializes the SQLite database.
"""
Expand All @@ -45,7 +45,7 @@ def init_db(self) -> None:
conn.commit()
conn.close()

def get_cached_location(self, place_name: str) -> Optional[Tuple[float, float]]:
def _get_cached_location(self, place_name: str) -> Optional[Tuple[float, float]]:
"""
Retrieves a cached location from the SQLite database.
Expand All @@ -65,7 +65,7 @@ def get_cached_location(self, place_name: str) -> Optional[Tuple[float, float]]:
conn.close()
return result

def cache_location(
def _cache_location(
self, place_name: str, latitude: Optional[float], longitude: Optional[float]
) -> None:
"""
Expand Down Expand Up @@ -97,26 +97,26 @@ def geocode_place(self, place_name: str) -> Optional[Tuple[float, float]]:
Returns:
Optional[Tuple[float, float]]: A tuple containing the latitude and longitude, or None if not found.
"""
cached_location = self.get_cached_location(place_name)
cached_location = self._get_cached_location(place_name)
if cached_location:
return cached_location

current_time = time.time()
elapsed_time = current_time - self.last_request_time
if elapsed_time < 1:
time.sleep(1) # Respect the rate limit of 1 request per second
self.last_request_time = time.time()

try:
location = self.geolocator.geocode(place_name)
except GeocoderServiceError as e:
logging.error(f"Geocoding request for '{place_name}' timed out: {e}")
lat, long = None, None
lat, long = cached_location
else:
if location:
lat, long = location.latitude, location.longitude
else:
current_time = time.time()
elapsed_time = current_time - self.last_request_time
if elapsed_time < 1:
time.sleep(1) # Respect the rate limit of 1 request per second
self.last_request_time = time.time()

try:
location = self.geolocator.geocode(place_name)
except GeocoderServiceError as e:
logging.error(f"Geocoding request for '{place_name}' timed out: {e}")
lat, long = None, None
self.cache_location(place_name, lat, long)
else:
if location:
lat, long = location.latitude, location.longitude
else:
lat, long = None, None
self._cache_location(place_name, lat, long)

return lat, long
66 changes: 66 additions & 0 deletions tests/test_geocoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import time
from unittest.mock import MagicMock

import pytest

from tempo_embeddings.io.geocoder import Geocoder


@pytest.fixture
def location():
return 12.34, 56.78


@pytest.fixture
def geocoder(tmp_path):
test_db_path = tmp_path / "test_geocode_cache.db"
yield Geocoder(db_path=str(test_db_path))


@pytest.fixture
def mock_geocoder(mocker, location):
mock_geocoder = mocker.patch("tempo_embeddings.io.geocoder.Nominatim.geocode")
mock_geocoder.return_value = MagicMock(latitude=location[0], longitude=location[1])
return mock_geocoder


class TestGeocoder:
def test_init(self, geocoder):
assert os.path.exists(geocoder.db_path)

def test_cache_location_and_get_cached_location(self, geocoder, location):
place_name = "Test Place"

geocoder._cache_location(place_name, location[0], location[1])
cached_location = geocoder._get_cached_location(place_name)
assert cached_location == location

def test_geocode_place(self, mock_geocoder, geocoder, location):
place_name = "Test Place"

assert geocoder.geocode_place(place_name) == location

# test caching:
geocoder.geocode_place(place_name)
(
mock_geocoder.assert_called_once_with(place_name),
"Should have called the remote service exactly once.",
)

def test_geocode_place_with_rate_limit(self, mock_geocoder, geocoder, location):
place_name = "Test Place"

geocoder.last_request_time = time.time()
start_time = time.time()
assert geocoder.geocode_place(place_name) == location

assert time.time() - start_time >= 1, "Should have delayed the remote call"
assert (
geocoder.last_request_time >= start_time + 1
), "Should have updated last_request_time"

(
mock_geocoder.assert_called_once_with(place_name),
"Should have called the remote service exactly once.",
)

0 comments on commit d46cb2e

Please sign in to comment.