Skip to content

Commit

Permalink
rf: refactored IProtoSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
penglei0 committed Oct 22, 2024
1 parent 9e24e32 commit 4e43228
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 32 deletions.
3 changes: 2 additions & 1 deletion src/interfaces/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from containernet.topology import (ITopology)
from protosuites.proto import IProtoSuite
from interfaces.routing import IRoutingStrategy
from interfaces.host import IHost
from testsuites.test import (ITestSuite)


Expand All @@ -23,7 +24,7 @@ def stop(self):
pass

@abstractmethod
def get_hosts(self):
def get_hosts(self) -> list[IHost]:
pass

@abstractmethod
Expand Down
3 changes: 1 addition & 2 deletions src/protosuites/bats/bats_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from interfaces.host import IHost
from tools.cfg_generator import generate_cfg_files
from protosuites.proto import (ProtoConfig, IProtoSuite)
from protosuites.proto_info import IProtoInfo
from var.global_var import g_root_path


class BATSProtocol(IProtoSuite, IProtoInfo):
class BATSProtocol(IProtoSuite):
def __init__(self, config: ProtoConfig):
super().__init__(config)
if self.config.path is None:
Expand Down
3 changes: 1 addition & 2 deletions src/protosuites/cs_protocol.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
from interfaces.network import INetwork
from protosuites.proto import (ProtoConfig, IProtoSuite)
from protosuites.proto_info import IProtoInfo


class CSProtocol(IProtoSuite, IProtoInfo):
class CSProtocol(IProtoSuite):
def __init__(self, config: ProtoConfig, client: IProtoSuite, server: IProtoSuite):
super().__init__(config)
self.client = client
Expand Down
14 changes: 7 additions & 7 deletions src/protosuites/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import IntEnum
from dataclasses import dataclass, field
from typing import (Optional, List)
from protosuites.proto_info import IProtoInfo
from var.global_var import g_root_path


Expand All @@ -29,16 +30,15 @@ class ProtoConfig:
port: Optional[int] = field(default=0)
type: Optional[str] = field(default='distributed')
test_name: str = field(default="")
protocols: Optional[List['ProtoConfig']] = field(
default=None) # type: ignore
protocols: Optional[List['ProtoConfig']] = field(default=None)
config_base_path: Optional[str] = field(default=None)


SupportedProto = ['btp', 'brtp', 'brtp_proxy', 'tcp', 'kcp', 'quic']
SupportedBATSProto = ['btp', 'brtp', 'brtp_proxy']


class IProtoSuite(ABC):
class IProtoSuite(IProtoInfo, ABC):
def __init__(self, config: ProtoConfig):
self.is_success = False
self.config = config
Expand Down Expand Up @@ -68,18 +68,18 @@ def get_config(self) -> ProtoConfig:
return self.config

@abstractmethod
def post_run(self, network: 'INetwork'): # type: ignore
def post_run(self, network: 'INetwork') -> bool: # type: ignore
pass

@abstractmethod
def pre_run(self, network: 'INetwork'): # type: ignore
def pre_run(self, network: 'INetwork') -> bool: # type: ignore
pass

@abstractmethod
def run(self, network: 'INetwork'): # type: ignore
def run(self, network: 'INetwork') -> bool: # type: ignore
pass

def start(self, network: 'INetwork') -> bool:
def start(self, network: 'INetwork') -> bool: # type: ignore
self.is_success = self.pre_run(network)
if not self.is_success:
logging.debug("pre_run failed")
Expand Down
20 changes: 10 additions & 10 deletions src/protosuites/std_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from protosuites.proto import (ProtoConfig, IProtoSuite)
from interfaces.network import INetwork
from var.global_var import g_root_path
from .proto_info import IProtoInfo


class StdProtocol(IProtoSuite, IProtoInfo):
class StdProtocol(IProtoSuite):
"""StdProtocol is used to load the protocol which be described by YAML.
"""

Expand Down Expand Up @@ -127,9 +126,7 @@ def __restore_protocol_version(self, network: INetwork):

def __handle_tcp_version_restore(self, network: INetwork):
hosts = network.get_hosts()
if hosts is None:
return False
for host in hosts: # type: ignore
for host in hosts:
default_ver = self.default_version_dict[host.name()]
if default_ver is None:
continue
Expand All @@ -147,12 +144,15 @@ def __handle_tcp_version_setup(self, network: INetwork, version: str):
"TCP version %s is not supported, please check the configuration.", version)
return
hosts = network.get_hosts()
if hosts is None:
return
for host in hosts: # type: ignore
for host in hosts:
# read `tcp_congestion_control` before change
res = host.popen(
f"sysctl net.ipv4.tcp_congestion_control").stdout.read().decode('utf-8')
pf = host.popen(
f"sysctl net.ipv4.tcp_congestion_control")
if pf is None:
logging.error(
"Failed to get the tcp congestion control on %s", host.name())
continue
res = pf.stdout.read().decode('utf-8')
default_version = res.split('=')[-1].strip()
if default_version == version:
continue
Expand Down
19 changes: 9 additions & 10 deletions src/testsuites/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class TestResult:
record: The record file generated by the test which contains the test results.
file name pattern: <class_name>_<test_type>_<client_host>_<server_host>.log
"""
is_success: bool
pattern: str
record: str
is_success: bool = field(default=False)
pattern: str = field(default="")
record: str = field(default="")
result_dir: str = field(default=f"{g_root_path}")


Expand All @@ -90,15 +90,15 @@ def __init__(self, config: TestConfig) -> None:
self.result = TestResult(False, pattern="", record="")

@abstractmethod
def post_process(self):
def post_process(self) -> bool:
pass

@abstractmethod
def pre_process(self):
def pre_process(self) -> bool:
pass

@abstractmethod
def _run_test(self, network: 'INetwork', proto_info: IProtoInfo): # type: ignore
def _run_test(self, network: 'INetwork', proto_info: IProtoInfo) -> bool: # type: ignore
pass

def run(self, network: 'INetwork', proto_info: IProtoInfo) -> TestResult: # type: ignore
Expand All @@ -114,7 +114,7 @@ def run(self, network: 'INetwork', proto_info: IProtoInfo) -> TestResult: # typ
base_name = proto_info.get_protocol_name().upper()
self.result.record = self.result.result_dir + \
base_name + "_" + self.result.pattern
self.result.is_success = self.pre_process() # type: ignore
self.result.is_success = self.pre_process()
# checking for non-distributed protocols
if not proto_info.is_distributed():
if self.config.client_host is None:
Expand All @@ -127,12 +127,11 @@ def run(self, network: 'INetwork', proto_info: IProtoInfo) -> TestResult: # typ
return self.result
if not self.result.is_success:
return self.result
self.result.is_success = self._run_test(
network, proto_info) # type: ignore
self.result.is_success = self._run_test(network, proto_info)
if not self.result.is_success:
logging.error("ITestSuite %s failed.", self.config.name)
return self.result
self.result.is_success = self.post_process() # type: ignore
self.result.is_success = self.post_process()
if not self.result.is_success:
return self.result
return self.result
Expand Down

0 comments on commit 4e43228

Please sign in to comment.