Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve CountryFilter to play nice with multiple=True #445

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions django_countries/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,20 @@ class CountryFilter(admin.FieldListFilter):
title = _("Country") # type: ignore

def expected_parameters(self):
if self.field.multiple:
return [f"{self.field.name}__contains"]
return [self.field.name]

def choices(self, changelist):
value = self.used_parameters.get(self.field.name)
if self.field.multiple:
field_name = f"{self.field.name}__contains"
else:
field_name = self.field.name

value = self.used_parameters.get(field_name)
yield {
"selected": value is None,
"query_string": changelist.get_query_string({}, [self.field.name]),
"query_string": changelist.get_query_string({}, [field_name]),
"display": _("All"),
}
for lookup, title in self.lookup_choices(changelist):
Expand All @@ -29,18 +36,18 @@ def choices(self, changelist):
selected = force_str(lookup) == value
yield {
"selected": selected,
"query_string": changelist.get_query_string(
{self.field.name: lookup}, []
),
"query_string": changelist.get_query_string({field_name: lookup}, []),
"display": title,
}

def lookup_choices(self, changelist):
qs = changelist.model._default_manager.all()
codes = set(
qs.distinct()
.order_by(self.field.name)
.values_list(self.field.name, flat=True)
",".join(
qs.distinct()
.order_by(self.field.name)
.values_list(self.field.name, flat=True)
).split(",")
)
for k, v in self.field.get_choices(include_blank=False):
if k in codes:
Expand Down
69 changes: 69 additions & 0 deletions django_countries/tests/test_admin_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ class PersonAdmin(admin.ModelAdmin):
test_site.register(models.Person, PersonAdmin)


class MultiCountryAdmin(admin.ModelAdmin):
list_filter = [("countries", filters.CountryFilter)]


test_site.register(models.MultiCountry, MultiCountryAdmin)


class TestCountryFilter(TestCase):
def get_changelist_kwargs(self):
m = self.person_admin
Expand Down Expand Up @@ -74,3 +81,65 @@ def test_choices(self):

def test_choices_empty_selection(self):
return self._test_choices(selected_country_code=None)


class TestMultiCountryFilter(TestCase):
def get_changelist_kwargs(self):
m = self.multi_country_admin
sig = inspect.signature(ChangeList.__init__)
kwargs = {"model_admin": m}
for arg in list(sig.parameters)[2:]:
if hasattr(m, arg):
kwargs[arg] = getattr(m, arg)
return kwargs

def setUp(self):
models.MultiCountry.objects.create(countries=["NZ"])
models.MultiCountry.objects.create(countries=["FR", "AU"])
models.MultiCountry.objects.create(countries=["FR", "NZ"])
self.multi_country_admin = MultiCountryAdmin(models.MultiCountry, test_site)

def test_filter_none(self):
request = RequestFactory().get("/multi_country/")
request.user = AnonymousUser()
cl = ChangeList(request, **self.get_changelist_kwargs())
cl.get_results(request)
self.assertEqual(len(cl.result_list), models.MultiCountry.objects.count())

def test_filter_country(self):
request = RequestFactory().get(
"/multi_country/", data={"countries__contains": "NZ"}
)
request.user = AnonymousUser()
cl = ChangeList(request, **self.get_changelist_kwargs())
cl.get_results(request)
self.assertQuerysetEqual(
cl.result_list,
models.MultiCountry.objects.exclude(countries__contains="AU"),
ordered=False,
)

def _test_choices(self, selected_country_code="NZ"):
request_params = {}
selected_country = "All"

if selected_country_code:
request_params["countries__contains"] = selected_country_code
selected_country = countries.name(selected_country_code)

request = RequestFactory().get("/multi_country/", data=request_params)
request.user = AnonymousUser()
cl = ChangeList(request, **self.get_changelist_kwargs())
choices = list(cl.filter_specs[0].choices(cl))
self.assertEqual(
[c["display"] for c in choices],
["All", "Australia", "France", "New Zealand"],
)
for choice in choices:
self.assertEqual(choice["selected"], choice["display"] == selected_country)

def test_choices(self):
return self._test_choices()

def test_choices_empty_selection(self):
return self._test_choices(selected_country_code=None)