diff --git a/django/contrib/gis/geoip2.py b/django/contrib/gis/geoip2.py index 5f49954209df..d0f3bb9fb344 100644 --- a/django/contrib/gis/geoip2.py +++ b/django/contrib/gis/geoip2.py @@ -12,6 +12,7 @@ directory corresponding to settings.GEOIP_PATH. """ +import ipaddress import socket import warnings @@ -172,10 +173,10 @@ def __repr__(self): def _check_query(self, query, city=False, city_or_country=False): "Check the query and database availability." - # Making sure a string was passed in for the query. - if not isinstance(query, str): + if not isinstance(query, (str, ipaddress.IPv4Address, ipaddress.IPv6Address)): raise TypeError( - "GeoIP query must be a string, not type %s" % type(query).__name__ + "GeoIP query must be a string or instance of IPv4Address or " + "IPv6Address, not type %s" % type(query).__name__, ) # Extra checks for the existence of country and city databases. diff --git a/docs/ref/contrib/gis/geoip2.txt b/docs/ref/contrib/gis/geoip2.txt index aca31bf78b00..1d27e3965766 100644 --- a/docs/ref/contrib/gis/geoip2.txt +++ b/docs/ref/contrib/gis/geoip2.txt @@ -107,10 +107,11 @@ and given cache setting. Querying -------- -All the following querying routines may take either a string IP address -or a fully qualified domain name (FQDN). For example, both -``'205.186.163.125'`` and ``'djangoproject.com'`` would be valid query -parameters. +All the following querying routines may take an instance of +:class:`~ipaddress.IPv4Address` or :class:`~ipaddress.IPv6Address`, a string IP +address, or a fully qualified domain name (FQDN). For example, +``IPv4Address("205.186.163.125")``, ``"205.186.163.125"``, and +``"djangoproject.com"`` would all be valid query parameters. .. method:: GeoIP2.city(query) diff --git a/docs/releases/5.1.txt b/docs/releases/5.1.txt index e84a27a0ec42..d458811471a6 100644 --- a/docs/releases/5.1.txt +++ b/docs/releases/5.1.txt @@ -59,6 +59,9 @@ Minor features * :class:`~django.contrib.gis.db.models.Collect` is now supported on MySQL 8.0.24+. +* :class:`~django.contrib.gis.geoip2.GeoIP2` now allows querying using + :class:`ipaddress.IPv4Address` or :class:`ipaddress.IPv6Address` objects. + :mod:`django.contrib.messages` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/gis_tests/test_geoip2.py b/tests/gis_tests/test_geoip2.py index 9cd5ffbdfe70..412728b3f432 100644 --- a/tests/gis_tests/test_geoip2.py +++ b/tests/gis_tests/test_geoip2.py @@ -1,3 +1,4 @@ +import ipaddress import itertools import pathlib from unittest import mock, skipUnless @@ -25,15 +26,20 @@ def build_geoip_path(*parts): ) class GeoLite2Test(SimpleTestCase): fqdn = "sky.uk" - ipv4 = "2.125.160.216" - ipv6 = "::ffff:027d:a0d8" + ipv4_str = "2.125.160.216" + ipv6_str = "::ffff:027d:a0d8" + ipv4_addr = ipaddress.ip_address(ipv4_str) + ipv6_addr = ipaddress.ip_address(ipv6_str) + query_values = (fqdn, ipv4_str, ipv6_str, ipv4_addr, ipv6_addr) @classmethod def setUpClass(cls): # Avoid referencing __file__ at module level. cls.enterClassContext(override_settings(GEOIP_PATH=build_geoip_path())) # Always mock host lookup to avoid test breakage if DNS changes. - cls.enterClassContext(mock.patch("socket.gethostbyname", return_value=cls.ipv4)) + cls.enterClassContext( + mock.patch("socket.gethostbyname", return_value=cls.ipv4_str) + ) super().setUpClass() @@ -86,7 +92,10 @@ def test_bad_query(self): functions += (g.country, g.country_code, g.country_name) values = (123, 123.45, b"", (), [], {}, set(), frozenset(), GeoIP2) - msg = "GeoIP query must be a string, not type" + msg = ( + "GeoIP query must be a string or instance of IPv4Address or IPv6Address, " + "not type" + ) for function, value in itertools.product(functions, values): with self.subTest(function=function.__qualname__, type=type(value)): with self.assertRaisesMessage(TypeError, msg): @@ -94,7 +103,7 @@ def test_bad_query(self): def test_country(self): g = GeoIP2(city="") - for query in (self.fqdn, self.ipv4, self.ipv6): + for query in self.query_values: with self.subTest(query=query): self.assertEqual( g.country(query), @@ -108,7 +117,7 @@ def test_country(self): def test_city(self): g = GeoIP2(country="") - for query in (self.fqdn, self.ipv4, self.ipv6): + for query in self.query_values: with self.subTest(query=query): self.assertEqual( g.city(query), @@ -188,15 +197,17 @@ def test_repr(self): def test_check_query(self): g = GeoIP2() - self.assertEqual(g._check_query(self.ipv4), self.ipv4) - self.assertEqual(g._check_query(self.ipv6), self.ipv6) - self.assertEqual(g._check_query(self.fqdn), self.ipv4) + self.assertEqual(g._check_query(self.fqdn), self.ipv4_str) + self.assertEqual(g._check_query(self.ipv4_str), self.ipv4_str) + self.assertEqual(g._check_query(self.ipv6_str), self.ipv6_str) + self.assertEqual(g._check_query(self.ipv4_addr), self.ipv4_addr) + self.assertEqual(g._check_query(self.ipv6_addr), self.ipv6_addr) def test_coords_deprecation_warning(self): g = GeoIP2() msg = "GeoIP2.coords() is deprecated. Use GeoIP2.lon_lat() instead." with self.assertWarnsMessage(RemovedInDjango60Warning, msg): - e1, e2 = g.coords(self.ipv4) + e1, e2 = g.coords(self.ipv4_str) self.assertIsInstance(e1, float) self.assertIsInstance(e2, float)