Skip to content

Commit

Permalink
Support CIDR matching.
Browse files Browse the repository at this point in the history
  - Match exact CIDR or IP within a CIDR.
  • Loading branch information
terjekv committed Aug 12, 2024
1 parent da19a07 commit bf55f7a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
20 changes: 14 additions & 6 deletions mreg/api/v1/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from mreg.models.zone import (ForwardZone, ForwardZoneDelegation, NameServer,
ReverseZone, ReverseZoneDelegation)

from netaddr import IPNetwork, AddrFormatError

mreg_log = structlog.getLogger(__name__)

OperatorList = List[str]
Expand Down Expand Up @@ -45,10 +47,16 @@
"created_at": INT_OPERATORS,
"updated_at": INT_OPERATORS,
}
class CIDRFieldFilter(filters.CharFilter):
def filter(self, qs, value):
if not value:
return qs

class CIDRFieldExactFilter(filters.CharFilter):
pass

try:
cidr = IPNetwork(value)
return qs.filter(**{f"{self.field_name}__net_contains_or_equals": str(cidr)})
except AddrFormatError:
return qs.none()

class BACnetIDFilterSet(filters.FilterSet):
class Meta:
Expand Down Expand Up @@ -266,7 +274,7 @@ class Meta:


class NetGroupRegexPermissionFilterSet(filters.FilterSet):
range = CIDRFieldExactFilter(field_name="range")
range = CIDRFieldFilter(field_name="range")

class Meta:
model = NetGroupRegexPermission
Expand All @@ -280,7 +288,7 @@ class Meta:


class NetworkFilterSet(filters.FilterSet):
network = CIDRFieldExactFilter(field_name="network")
network = CIDRFieldFilter(field_name="network")

class Meta:
model = Network
Expand Down Expand Up @@ -329,7 +337,7 @@ class Meta:


class ReverseZoneFilterSet(filters.FilterSet):
network = CIDRFieldExactFilter(field_name="network")
network = CIDRFieldFilter(field_name="network")

class Meta:
model = ReverseZone
Expand Down
31 changes: 31 additions & 0 deletions mreg/api/v1/tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from hostpolicy.models import HostPolicyAtom, HostPolicyRole
from mreg.models.base import Label
from mreg.models.host import Host, Ipaddress
from mreg.models.network import NetGroupRegexPermission
from mreg.models.resource_records import Cname

from .tests import MregAPITestCase
Expand Down Expand Up @@ -233,3 +234,33 @@ def test_filtering_for_hostpolicy(self, endpoint: str, query_key: str, target: s

for obj in chain(roles, atoms, labels, hosts):
obj.delete()

@parametrize(("cidr", "exists"), [
param("10.0.0.0/24", True, id="cidr_0_true"),
param("10.0.1.0/24", True, id="cidr_1_true"),
param("10.0.2.0/24", True, id="cidr_2_true"),
param("10.0.3.0/24", False, id="cidr_3_false"),
param("10.0.0.1", True, id="ip_0_1_true"),
param("10.0.0.2", True, id="ip_0_2_true"),
param("10.0.1.1", True, id="ip_1_1_true"),
param("10.0.2.1", True, id="ip_2_1_true"),
param("10.0.3.1", False, id="ip_3_1_false"),
],
)
def test_filter_netgroup_regex_permission(self, cidr: str, exists: bool) -> None:
"""Test filtering on netgroup regex permission."""

generate_count = 3

for i in range(generate_count):
NetGroupRegexPermission.objects.create(
regex=".*",
range=f"10.0.{i}.0/24"
)

response = self.client.get(f"/api/v1/permissions/netgroupregex/?range={cidr}")
self.assertEqual(response.status_code, 200)
data = response.json()
self.assertEqual(data["count"], 1 if exists else 0)

0 comments on commit bf55f7a

Please sign in to comment.