diff --git a/django_countries/filters.py b/django_countries/filters.py index eb5a150..2f9d227 100644 --- a/django_countries/filters.py +++ b/django_countries/filters.py @@ -13,13 +13,20 @@ class CountryFilter(admin.FieldListFilter): title = _("Country") # type: ignore def expected_parameters(self): + if self.field.multiple: + return [f"{self.field.name}__contains"] return [self.field.name] def choices(self, changelist): - value = self.used_parameters.get(self.field.name) + if self.field.multiple: + field_name = f"{self.field.name}__contains" + else: + field_name = self.field.name + + value = self.used_parameters.get(field_name) yield { "selected": value is None, - "query_string": changelist.get_query_string({}, [self.field.name]), + "query_string": changelist.get_query_string({}, [field_name]), "display": _("All"), } for lookup, title in self.lookup_choices(changelist): @@ -29,18 +36,18 @@ def choices(self, changelist): selected = force_str(lookup) == value yield { "selected": selected, - "query_string": changelist.get_query_string( - {self.field.name: lookup}, [] - ), + "query_string": changelist.get_query_string({field_name: lookup}, []), "display": title, } def lookup_choices(self, changelist): qs = changelist.model._default_manager.all() codes = set( - qs.distinct() - .order_by(self.field.name) - .values_list(self.field.name, flat=True) + ",".join( + qs.distinct() + .order_by(self.field.name) + .values_list(self.field.name, flat=True) + ).split(",") ) for k, v in self.field.get_choices(include_blank=False): if k in codes: diff --git a/django_countries/tests/test_admin_filters.py b/django_countries/tests/test_admin_filters.py index ce66d0f..568ad50 100644 --- a/django_countries/tests/test_admin_filters.py +++ b/django_countries/tests/test_admin_filters.py @@ -19,6 +19,13 @@ class PersonAdmin(admin.ModelAdmin): test_site.register(models.Person, PersonAdmin) +class MultiCountryAdmin(admin.ModelAdmin): + list_filter = [("countries", filters.CountryFilter)] + + +test_site.register(models.MultiCountry, MultiCountryAdmin) + + class TestCountryFilter(TestCase): def get_changelist_kwargs(self): m = self.person_admin @@ -74,3 +81,60 @@ def test_choices(self): def test_choices_empty_selection(self): return self._test_choices(selected_country_code=None) + + +class TestMultiCountryFilter(TestCase): + def get_changelist_kwargs(self): + m = self.multi_country_admin + sig = inspect.signature(ChangeList.__init__) + kwargs = {"model_admin": m} + for arg in list(sig.parameters)[2:]: + if hasattr(m, arg): + kwargs[arg] = getattr(m, arg) + return kwargs + + def setUp(self): + models.MultiCountry.objects.create(countries=["NZ"]) + models.MultiCountry.objects.create(countries=["FR", "AU"]) + models.MultiCountry.objects.create(countries=["FR", "NZ"]) + self.multi_country_admin = MultiCountryAdmin(models.MultiCountry, test_site) + + def test_filter_none(self): + request = RequestFactory().get("/multi_country/") + request.user = AnonymousUser() + cl = ChangeList(request, **self.get_changelist_kwargs()) + cl.get_results(request) + self.assertEqual(len(cl.result_list), models.MultiCountry.objects.count()) + + def test_filter_country(self): + request = RequestFactory().get("/multi_country/", data={"countries__contains": "NZ"}) + request.user = AnonymousUser() + cl = ChangeList(request, **self.get_changelist_kwargs()) + cl.get_results(request) + self.assertQuerysetEqual( + cl.result_list, models.MultiCountry.objects.exclude(countries__contains="AU"), ordered=False + ) + + def _test_choices(self, selected_country_code="NZ"): + request_params = {} + selected_country = "All" + + if selected_country_code: + request_params["countries__contains"] = selected_country_code + selected_country = countries.name(selected_country_code) + + request = RequestFactory().get("/multi_country/", data=request_params) + request.user = AnonymousUser() + cl = ChangeList(request, **self.get_changelist_kwargs()) + choices = list(cl.filter_specs[0].choices(cl)) + self.assertEqual( + [c["display"] for c in choices], ["All", "Australia", "France", "New Zealand"] + ) + for choice in choices: + self.assertEqual(choice["selected"], choice["display"] == selected_country) + + def test_choices(self): + return self._test_choices() + + def test_choices_empty_selection(self): + return self._test_choices(selected_country_code=None)