Skip to content

Commit

Permalink
Make ipnetwork and ipaddress hashable (#148)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Yun Zheng Hu <[email protected]>
  • Loading branch information
Miauwkeru and yunzheng authored Oct 9, 2024
1 parent 4e1a285 commit c4e5641
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 18 deletions.
55 changes: 37 additions & 18 deletions flow/record/fieldtypes/net/ip.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,75 @@
from ipaddress import ip_address, ip_network
from __future__ import annotations

from ipaddress import (
IPv4Address,
IPv4Network,
IPv6Address,
IPv6Network,
ip_address,
ip_network,
)
from typing import Union

from flow.record.base import FieldType
from flow.record.fieldtypes import defang

_IPNetwork = Union[IPv4Network, IPv6Network]
_IPAddress = Union[IPv4Address, IPv6Address]


class ipaddress(FieldType):
val = None
_type = "net.ipaddress"

def __init__(self, addr):
def __init__(self, addr: str | int | bytes):
self.val = ip_address(addr)

def __eq__(self, b):
def __eq__(self, b: str | int | bytes | _IPAddress) -> bool:
try:
return self.val == ip_address(b)
except ValueError:
return False

def __str__(self):
def __hash__(self) -> int:
return hash(self.val)

def __str__(self) -> str:
return str(self.val)

def __repr__(self):
return "{}({!r})".format(self._type, str(self))
def __repr__(self) -> str:
return f"{self._type}({str(self)!r})"

def __format__(self, spec):
def __format__(self, spec: str) -> str:
if spec == "defang":
return defang(str(self))
return str.__format__(str(self), spec)

def _pack(self):
def _pack(self) -> int:
return int(self.val)

@staticmethod
def _unpack(data):
def _unpack(data: int) -> ipaddress:
return ipaddress(data)


class ipnetwork(FieldType):
val = None
_type = "net.ipnetwork"

def __init__(self, addr):
def __init__(self, addr: str | int | bytes):
self.val = ip_network(addr)

def __eq__(self, b):
def __eq__(self, b: str | int | bytes | _IPNetwork) -> bool:
try:
return self.val == ip_network(b)
except ValueError:
return False

def __hash__(self) -> int:
return hash(self.val)

@staticmethod
def _is_subnet_of(a, b):
def _is_subnet_of(a: _IPNetwork, b: _IPNetwork) -> bool:
try:
# Always false if one is v4 and the other is v6.
if a._version != b._version:
Expand All @@ -59,23 +78,23 @@ def _is_subnet_of(a, b):
except AttributeError:
raise TypeError("Unable to test subnet containment " "between {} and {}".format(a, b))

def __contains__(self, b):
def __contains__(self, b: str | int | bytes | _IPAddress) -> bool:
try:
return self._is_subnet_of(ip_network(b), self.val)
except (ValueError, TypeError):
return False

def __str__(self):
def __str__(self) -> str:
return str(self.val)

def __repr__(self):
return "{}({!r})".format(self._type, str(self))
def __repr__(self) -> str:
return f"{self._type}({str(self)!r})"

def _pack(self):
def _pack(self) -> str:
return self.val.compressed

@staticmethod
def _unpack(data):
def _unpack(data: str) -> ipnetwork:
return ipnetwork(data)


Expand Down
15 changes: 15 additions & 0 deletions tests/test_fieldtype_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,19 @@ def test_record_ipaddress():
assert TestRecord("0.0.0.0").ip == "0.0.0.0"
assert TestRecord("192.168.0.1").ip == "192.168.0.1"
assert TestRecord("255.255.255.255").ip == "255.255.255.255"
assert hash(TestRecord("192.168.0.1").ip) == hash(net.ipaddress("192.168.0.1"))

# ipv6
assert TestRecord("::1").ip == "::1"
assert TestRecord("2001:4860:4860::8888").ip == "2001:4860:4860::8888"
assert TestRecord("2001:4860:4860::4444").ip == "2001:4860:4860::4444"

# Test whether it functions in a set
data = {TestRecord(ip).ip for ip in ["192.168.0.1", "192.168.0.1", "::1", "::1"]}
assert len(data) == 2
assert net.ipaddress("::1") in data
assert net.ipaddress("192.168.0.1") in data

# instantiate from different types
assert TestRecord(1).ip == "0.0.0.1"
assert TestRecord(0x7F0000FF).ip == "127.0.0.255"
Expand Down Expand Up @@ -90,6 +97,7 @@ def test_record_ipnetwork():
assert "192.168.1.1" not in r.subnet
assert isinstance(r.subnet, net.ipnetwork)
assert repr(r.subnet) == "net.ipnetwork('192.168.0.0/24')"
assert hash(r.subnet) == hash(net.ipnetwork("192.168.0.0/24"))

r = TestRecord("192.168.1.1/32")
assert r.subnet == "192.168.1.1"
Expand All @@ -111,6 +119,13 @@ def test_record_ipnetwork():
assert "64:ff9b::0.0.0.0" in r.subnet
assert "64:ff9b::255.255.255.255" in r.subnet

# Test whether it functions in a set
data = {TestRecord(x).subnet for x in ["192.168.0.0/24", "192.168.0.0/24", "::1", "::1"]}
assert len(data) == 2
assert net.ipnetwork("::1") in data
assert net.ipnetwork("192.168.0.0/24") in data
assert "::1" not in data


@pytest.mark.parametrize("PSelector", [Selector, CompiledSelector])
def test_selector_ipaddress(PSelector):
Expand Down

0 comments on commit c4e5641

Please sign in to comment.