diff --git a/tests/conftest.py b/tests/conftest.py index 65242ca..80ea65d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -186,6 +186,34 @@ def dump(self): for n in self.neighbours: print(f" {n}") +class MCTPCommand: + def __init__(self): + self.send_channel, self.receive_channel = trio.open_memory_channel(0) + + async def complete(self, data): + async with self.send_channel as chan: + await chan.send(data) + + async def wait(self): + async with self.receive_channel as chan: + return await chan.receive() + +class MCTPControlCommand(MCTPCommand): + MSGTYPE = 0 + + def __init__(self, rq, iid, cmd, data = bytes()): + super().__init__() + self.rq = rq + self.iid = iid + self.cmd = cmd + self.data = data + + def to_buf(self): + flags = self.iid + if self.rq: + flags = flags | 0x80 + return bytes([flags, self.cmd]) + self.data + class Endpoint: def __init__(self, iface, lladdr, ep_uuid = None, eid = 0, types = None): self.iface = iface @@ -193,6 +221,8 @@ def __init__(self, iface, lladdr, ep_uuid = None, eid = 0, types = None): self.uuid = ep_uuid or uuid.uuid1() self.eid = eid self.types = types or [0] + # keyed by (type, type-specific-instance) + self.commands = {} def __str__(self): return f"uuid {self.uuid} lladdr {self.lladdr}, eid {self.eid}" @@ -208,41 +238,72 @@ async def handle_mctp_message(self, sock, addr, data): async def handle_mctp_control(self, sock, addr, data): flags, opcode = data[0:2] + rq = flags & 0x80 + iid = flags & 0x1f - # we're only a responder, ensure we have the Rq bit set - assert flags & 0x80 + if not rq: + cmd = self.commands.get((0, iid), None) + assert cmd is not None, "unexpected response?" - raddr = MCTPSockAddr.for_ep_resp(self, addr, sock.addr_ext) - hdr = [0x00, opcode] - if opcode == 1: - # Set Endpoint ID - (op, eid) = data[2:] - self.eid = eid - data = bytes(hdr + [0x00, 0x00, self.eid, 0x00]) - await sock.send(raddr, data) + await cmd.complete(data) - elif opcode == 2: - # Get Endpoint ID - data = bytes(hdr + [0x00, self.eid, 0x00, 0x00]) - await sock.send(raddr, data) + else: - elif opcode == 3: - # Get Endpoint UUID - data = bytes(hdr + [0x00]) + self.uuid.bytes - await sock.send(raddr, data) + raddr = MCTPSockAddr.for_ep_resp(self, addr, sock.addr_ext) + hdr = [0x00, opcode] + if opcode == 1: + # Set Endpoint ID + (op, eid) = data[2:] + self.eid = eid + data = bytes(hdr + [0x00, 0x00, self.eid, 0x00]) + await sock.send(raddr, data) + + elif opcode == 2: + # Get Endpoint ID + data = bytes(hdr + [0x00, self.eid, 0x00, 0x00]) + await sock.send(raddr, data) + + elif opcode == 3: + # Get Endpoint UUID + data = bytes(hdr + [0x00]) + self.uuid.bytes + await sock.send(raddr, data) + + elif opcode == 5: + # Get Message Type Support + types = self.types + data = bytes(hdr + [0x00, len(types)] + types) + await sock.send(raddr, data) + + else: + await sock.send(raddr, bytes(hdr + [0x05])) # unsupported command + + async def send_control(self, sock, cmd_code, data = bytes()): + + typ = MCTPControlCommand.MSGTYPE + # todo: tag 0 implied + addr = MCTPSockAddr(self.iface.net, self.eid, typ, 0x80) + if sock.addr_ext: + addr.set_ext(self.iface.ifindex, self.lladdr) + + # todo: multiple commands; iid 0 implied + iid = 0 + cmd = MCTPControlCommand(True, iid, cmd_code, data) + + key = (typ, cmd.iid) + assert not key in self.commands + + self.commands[key] = cmd + + await sock.send(addr, cmd.to_buf()) + + return await cmd.wait() - elif opcode == 5: - # Get Message Type Support - types = self.types - data = bytes(hdr + [0x00, len(types)] + types) - await sock.send(raddr, data) - else: - await sock.send(raddr, bytes(hdr + [0x05])) # unsupported command class Network: def __init__(self): self.endpoints = [] + self.mctp_socket = None def add_endpoint(self, endpoint): self.endpoints.append(endpoint) @@ -253,6 +314,12 @@ def lookup_endpoint(self, iface, lladdr): return ep return None + # register the core mctp control socket, on which incoming requests + # are sent to mctpd + def register_mctp_socket(self, socket): + assert self.mctp_socket is None + self.mctp_socket = socket + # MCTP-capable pyroute2 objects class ifinfmsg_mctp(rtnl.ifinfmsg.ifinfmsg): class af_spec(netlink.nla): @@ -474,6 +541,9 @@ async def handle_setsockopt(self, level, optname, optval): val = int.from_bytes(optval, byteorder = sys.byteorder) self.addr_ext = bool(val) + async def handle_bind(self, addr): + self.network.register_mctp_socket(self) + async def send(self, addr, data): addrbuf = addr.to_buf() addrlen = len(addrbuf)