diff --git a/examples/http-proxy/http_proxy.py b/examples/http-proxy/http_proxy.py index 8802b0e16..e49adfa02 100644 --- a/examples/http-proxy/http_proxy.py +++ b/examples/http-proxy/http_proxy.py @@ -201,6 +201,8 @@ def still_starting(): await asyncio.sleep(5) cnt += 1 + await network.remove() + if __name__ == "__main__": parser = build_parser("An extremely simple http proxy") diff --git a/examples/ssh/ssh.py b/examples/ssh/ssh.py index b9c7b9fda..25683e514 100755 --- a/examples/ssh/ssh.py +++ b/examples/ssh/ssh.py @@ -87,25 +87,26 @@ async def main(subnet_tag, payment_driver=None, payment_network=None): ) network = await golem.create_network("192.168.0.1/24") - cluster = await golem.run_service(SshService, network=network, num_instances=2) + async with network: + cluster = await golem.run_service(SshService, network=network, num_instances=2) - def instances(): - return [f"{s.provider_name}: {s.state.value}" for s in cluster.instances] + def instances(): + return [f"{s.provider_name}: {s.state.value}" for s in cluster.instances] - while True: - print(instances()) - try: - await asyncio.sleep(5) - except (KeyboardInterrupt, asyncio.CancelledError): - break + while True: + print(instances()) + try: + await asyncio.sleep(5) + except (KeyboardInterrupt, asyncio.CancelledError): + break - cluster.stop() + cluster.stop() - cnt = 0 - while cnt < 3 and any(s.is_available for s in cluster.instances): - print(instances()) - await asyncio.sleep(5) - cnt += 1 + cnt = 0 + while cnt < 3 and any(s.is_available for s in cluster.instances): + print(instances()) + await asyncio.sleep(5) + cnt += 1 if __name__ == "__main__": diff --git a/tests/factories/network.py b/tests/factories/network.py index eb37612da..ea3c79fdb 100644 --- a/tests/factories/network.py +++ b/tests/factories/network.py @@ -21,6 +21,7 @@ def _create(cls, model_class, *args, **kwargs): net_api.create_network = mock.AsyncMock( return_value=faker.Faker().binary(length=16).hex() ) + net_api.remove_network = mock.AsyncMock() kwargs["net_api"] = net_api # we're using `futures.ThreadPoolExecutor` here diff --git a/tests/test_network.py b/tests/test_network.py index 85391d790..df39c5092 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -2,118 +2,154 @@ import sys from unittest import mock -from yapapi.network import Network, NetworkError +from statemachine.exceptions import TransitionNotAllowed + +from yapapi.network import Network, NetworkError, NetworkState if sys.version_info >= (3, 8): from tests.factories.network import NetworkFactory -def test_init(): - ip = "192.168.0.0" - network = Network(mock.Mock(), f"{ip}/24", "0xdeadbeef") - assert network.network_id is None - assert network.owner_ip == "192.168.0.1" - assert network.network_address == ip - assert network.netmask == "255.255.255.0" - - -def test_init_mask(): - ip = "192.168.0.0" - mask = "255.255.0.0" - network = Network(mock.Mock(), ip, "0xcafed00d", mask=mask) - assert network.network_address == ip - assert network.netmask == mask - - -def test_init_duplicate_mask(): - with pytest.raises(NetworkError): - Network(mock.Mock(), "10.0.0.0/16", "0x0d15ea5e", mask="255.255.0.0") - - -@pytest.mark.asyncio @pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -def test_create(): - ip = "192.168.0.0" - owner_id = "0xcafebabe" - network = NetworkFactory(ip=f"{ip}/24", owner_id=owner_id) - assert network.network_id - assert network.owner_ip == "192.168.0.1" - assert network.network_address == ip - assert network.netmask == "255.255.255.0" - assert network.nodes_dict == {"192.168.0.1": owner_id} - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -def test_create_with_owner_ip(): - network = NetworkFactory(ip="192.168.0.0/24", owner_ip="192.168.0.2") - assert list(network.nodes_dict.keys()) == ["192.168.0.2"] - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -def test_create_with_owner_ip_outside_network(): - with pytest.raises(NetworkError) as e: - NetworkFactory(ip="192.168.0.0/24", owner_ip="192.168.1.1") - - assert "address must belong to the network" in str(e.value) - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -async def test_add_node(): - network = NetworkFactory(ip="192.168.0.0/24") - node1 = await network.add_node("1") - assert node1.ip == "192.168.0.2" - node2 = await network.add_node("2") - assert node2.ip == "192.168.0.3" - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -async def test_add_node_owner_ip_different(): - network = NetworkFactory(ip="192.168.0.0/24", owner_ip="192.168.0.2") - node1 = await network.add_node("1") - assert node1.ip == "192.168.0.1" - node2 = await network.add_node("2") - assert node2.ip == "192.168.0.3" - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -async def test_add_node_specific_ip(): - network = NetworkFactory(ip="192.168.0.0/24") - ip = "192.168.0.5" - node = await network.add_node("1", ip) - assert node.ip == ip - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -async def test_add_node_ip_collision(): - network = NetworkFactory(ip="192.168.0.0/24", owner_ip="192.168.0.2") - with pytest.raises(NetworkError) as e: - await network.add_node("1", "192.168.0.2") - - assert "has already been assigned in this network" in str(e.value) - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -async def test_add_node_ip_outside_network(): - network = NetworkFactory(ip="192.168.0.0/24") - with pytest.raises(NetworkError) as e: - await network.add_node("1", "192.168.1.2") - - assert "address must belong to the network" in str(e.value) - - -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+") -async def test_add_node_pool_depleted(): - network = NetworkFactory(ip="192.168.0.0/30") - await network.add_node("1") - with pytest.raises(NetworkError) as e: - await network.add_node("2") - - assert "No more addresses available" in str(e.value) +class TestNetwork: + def test_init(self): + ip = "192.168.0.0" + network = Network(mock.Mock(), f"{ip}/24", "0xdeadbeef") + assert network.owner_ip == "192.168.0.1" + assert network.network_address == ip + assert network.netmask == "255.255.255.0" + + def test_init_mask(self): + ip = "192.168.0.0" + mask = "255.255.0.0" + network = Network(mock.Mock(), ip, "0xcafed00d", mask=mask) + assert network.network_address == ip + assert network.netmask == mask + + def test_init_duplicate_mask(self): + with pytest.raises(NetworkError): + Network(mock.Mock(), "10.0.0.0/16", "0x0d15ea5e", mask="255.255.0.0") + + @pytest.mark.asyncio + def test_create(self): + ip = "192.168.0.0" + owner_id = "0xcafebabe" + network = NetworkFactory(ip=f"{ip}/24", owner_id=owner_id) + assert network.network_id + assert network.owner_ip == "192.168.0.1" + assert network.network_address == ip + assert network.netmask == "255.255.255.0" + assert network.nodes_dict == {"192.168.0.1": owner_id} + assert network.state == NetworkState.ready + network._net_api.create_network.assert_called_with( + network.network_address, network.netmask, network.gateway + ) + + @pytest.mark.asyncio + def test_create_with_owner_ip(self): + network = NetworkFactory(ip="192.168.0.0/24", owner_ip="192.168.0.2") + assert list(network.nodes_dict.keys()) == ["192.168.0.2"] + + @pytest.mark.asyncio + def test_create_with_owner_ip_outside_network(self): + with pytest.raises(NetworkError) as e: + NetworkFactory(ip="192.168.0.0/24", owner_ip="192.168.1.1") + + assert "address must belong to the network" in str(e.value) + + @pytest.mark.asyncio + async def test_add_node(self): + network = NetworkFactory(ip="192.168.0.0/24") + node1 = await network.add_node("1") + assert node1.ip == "192.168.0.2" + node2 = await network.add_node("2") + assert node2.ip == "192.168.0.3" + + @pytest.mark.asyncio + async def test_add_node_owner_ip_different(self): + network = NetworkFactory(ip="192.168.0.0/24", owner_ip="192.168.0.2") + node1 = await network.add_node("1") + assert node1.ip == "192.168.0.1" + node2 = await network.add_node("2") + assert node2.ip == "192.168.0.3" + + @pytest.mark.asyncio + async def test_add_node_specific_ip(self): + network = NetworkFactory(ip="192.168.0.0/24") + ip = "192.168.0.5" + node = await network.add_node("1", ip) + assert node.ip == ip + + @pytest.mark.asyncio + async def test_add_node_ip_collision(self): + network = NetworkFactory(ip="192.168.0.0/24", owner_ip="192.168.0.2") + with pytest.raises(NetworkError) as e: + await network.add_node("1", "192.168.0.2") + + assert "has already been assigned in this network" in str(e.value) + + @pytest.mark.asyncio + async def test_add_node_ip_outside_network(self): + network = NetworkFactory(ip="192.168.0.0/24") + with pytest.raises(NetworkError) as e: + await network.add_node("1", "192.168.1.2") + + assert "address must belong to the network" in str(e.value) + + @pytest.mark.asyncio + async def test_add_node_pool_depleted(self): + network = NetworkFactory(ip="192.168.0.0/30") + await network.add_node("1") + with pytest.raises(NetworkError) as e: + await network.add_node("2") + + assert "No more addresses available" in str(e.value) + + @pytest.mark.asyncio + async def test_id_when_initialized(self): + network = Network(mock.Mock(), f"192.168.0.0/24", "0xdeadbeef") + with pytest.raises(TransitionNotAllowed, match=".*Can't get_id when in initialized.*") as e: + im_gonna_fail = network.network_id + + @pytest.mark.asyncio + async def test_id_when_removed(self): + network = NetworkFactory(ip="192.168.0.0/24") + assert network.network_id + + await network.remove() + + with pytest.raises(TransitionNotAllowed, match=".*Can't get_id when in removed.*") as e: + im_gonna_fail = network.network_id + + @pytest.mark.asyncio + async def test_remove(self): + network = NetworkFactory(ip="192.168.0.0/24") + + await network.remove() + + network._net_api.remove_network.assert_called_once() + + @pytest.mark.asyncio + async def test_remove_when_initialized(self): + network = Network(mock.Mock(), f"192.168.0.0/24", "0xdeadbeef") + with pytest.raises(TransitionNotAllowed, match=".*Can't stop when in initialized.*") as e: + await network.remove() + + @pytest.mark.asyncio + async def test_remove_when_removed(self): + network = NetworkFactory(ip="192.168.0.0/24") + + await network.remove() + + with pytest.raises(TransitionNotAllowed, match=".*Can't stop when in removed.*") as e: + await network.remove() + + @pytest.mark.asyncio + async def test_network_context_manager(self): + network = NetworkFactory(ip="192.168.0.0/24") + assert network.state == NetworkState.ready + + async with network: + pass + + assert network.state == NetworkState.removed diff --git a/yapapi/network.py b/yapapi/network.py index ee2f0f73d..817ebcf9d 100644 --- a/yapapi/network.py +++ b/yapapi/network.py @@ -1,9 +1,15 @@ import asyncio +import logging from dataclasses import dataclass from ipaddress import ip_address, ip_network, IPv4Address, IPv6Address, IPv4Network, IPv6Network +from statemachine import State, StateMachine # type: ignore from typing import Dict, Optional, Union from urllib.parse import urlparse + import yapapi +from ya_net.exceptions import ApiException + +logger = logging.getLogger("yapapi.network") IpAddress = Union[IPv4Address, IPv6Address] IpNetwork = Union[IPv4Network, IPv6Network] @@ -54,6 +60,28 @@ def get_websocket_uri(self, port: int) -> str: return f"{net_api_ws}/net/{self.network.network_id}/tcp/{self.ip}/{port}" +class NetworkState(StateMachine): + """State machine describing the states and lifecycle of a :class:`Network` instance.""" + + # states + initialized = State("initialized", initial=True) + creating = State("creating") + ready = State("ready") + removing = State("removing") + removed = State("removed") + + # state-altering transitions (lifecycle) + create = initialized.to(creating) + start = creating.to(ready) + stop = ready.to(removing) + remove = removing.to(removed) + + # same-state transitions + add_owner_address = creating.to.itself() | ready.to.itself() + add_node = ready.to.itself() + get_id = creating.to.itself() | ready.to.itself() | removing.to.itself() + + class Network: """ Describes a VPN created between the requestor and the provider nodes within Golem Network. @@ -80,6 +108,7 @@ async def create( """ network = cls(net_api, ip, owner_id, owner_ip, mask, gateway) + network._state_machine.create() # create the network in yagna and set the id network._network_id = await net_api.create_network( @@ -89,6 +118,7 @@ async def create( # add requestor's own address to the network await network.add_owner_address(network.owner_ip) + network._state_machine.start() return network def __init__( @@ -121,6 +151,7 @@ def __init__( self._gateway = gateway self._owner_id = owner_id self._owner_ip: IpAddress = ip_address(owner_ip) if owner_ip else self._next_address() + self._state_machine: NetworkState = NetworkState() self._nodes: Dict[str, Node] = dict() """the mapping between a Golem node id and a Node in this VPN.""" @@ -137,11 +168,22 @@ def __str__(self) -> str: nodes: {self.nodes_dict} }}""" + async def __aenter__(self) -> "Network": + return self + + async def __aexit__(self, *exc_info) -> None: + await self.remove() + @property def owner_ip(self) -> str: - """the IP address of the requestor node within the network""" + """The IP address of the requestor node within the network.""" return str(self._owner_ip) + @property + def state(self) -> State: + """Current state in this network's lifecycle.""" + return self._state_machine.current_state + @property def network_address(self) -> str: """The network address of this network, without a netmask.""" @@ -165,8 +207,10 @@ def nodes_dict(self) -> Dict[str, str]: return {str(v.ip): k for k, v in self._nodes.items()} @property - def network_id(self) -> Optional[str]: + def network_id(self) -> str: """The automatically-generated, unique ID of this VPN.""" + self._state_machine.get_id() + assert self._network_id return self._network_id def _ensure_ip_in_network(self, ip: str): @@ -186,8 +230,7 @@ async def add_owner_address(self, ip: str): :param ip: the IP address to assign to the requestor node. """ - assert self.network_id, "Network not initialized correctly" - + self._state_machine.add_owner_address() self._ensure_ip_in_network(ip) async with self._nodes_lock: @@ -202,7 +245,7 @@ async def add_node(self, node_id: str, ip: Optional[str] = None) -> Node: :param node_id: Node ID within the Golem network of this VPN node. :param ip: IP address to assign to this node. """ - assert self.network_id, "Network not initialized correctly" + self._state_machine.add_node() async with self._nodes_lock: if ip: @@ -221,6 +264,18 @@ async def add_node(self, node_id: str, ip: Optional[str] = None) -> Node: return node + async def remove(self) -> None: + """Remove this network, terminating any connections it provides.""" + self._state_machine.stop() + try: + await self._net_api.remove_network(self.network_id) + except ApiException as e: + if e.status == 404: + logger.debug( + "Tried removing a network which doesn't exist. network_id=%s", self.network_id + ) + self._state_machine.remove() + def _next_address(self) -> IpAddress: """Provide the next available IP address within this Network. diff --git a/yapapi/rest/net.py b/yapapi/rest/net.py index 8998e7a86..0ecd1ba85 100644 --- a/yapapi/rest/net.py +++ b/yapapi/rest/net.py @@ -29,6 +29,9 @@ async def create_network( ) return yan_network.id + async def remove_network(self, network_id: str) -> None: + await self._api.remove_network(network_id) + async def add_address(self, network_id: str, ip: str): address = yan.Address(ip) await self._api.add_address(network_id, address)