From 473126a8be3514795b22a5845188fcac5d3e14b0 Mon Sep 17 00:00:00 2001 From: Jeremy Kerr <jk@codeconstruct.com.au> Date: Mon, 5 Feb 2024 14:00:24 +0800 Subject: [PATCH] mctpd: check IIDs and opcodes in control protocol responses Currently we ignore the IID and opcode fields in control protocol response messages. This means we may get a delayed response from an endpoint, it may be interpreted incorrectly as a response to the next message. Instead, allocate new IIDs for each message, then check the IIDs in responses. We do this through a new common response validation function, mctp_ctrl_validate_response. This requires some additional string building for error messages. Signed-off-by: Jeremy Kerr <jk@codeconstruct.com.au> --- CHANGELOG.md | 3 + src/mctp-control-spec.h | 5 + src/mctpd.c | 240 +++++++++++++++++++++++++++++----------- tests/conftest.py | 5 +- tests/test_mctpd.py | 46 +++++++- 5 files changed, 231 insertions(+), 68 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 077edf8..bda8fbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 3. The `tests` option has changed type from `feature` to `boolean`. Tests are enabled by default. +4. We now enforce IID checks on MCTP control protocol responses; this + prevents odd behaviour from delayed or invalid responses. + ### Fixed 1. mctpd: EID assignments now work in the case where a new endpoint has a diff --git a/src/mctp-control-spec.h b/src/mctp-control-spec.h index ea6eb4f..ba29217 100644 --- a/src/mctp-control-spec.h +++ b/src/mctp-control-spec.h @@ -13,6 +13,11 @@ struct mctp_ctrl_msg_hdr { uint8_t command_code; } __attribute__((__packed__)); +struct mctp_ctrl_resp { + struct mctp_ctrl_msg_hdr ctrl_hdr; + uint8_t completion_code; +} __attribute__((packed)); + typedef enum { set_eid, force_eid, diff --git a/src/mctpd.c b/src/mctpd.c index 0437a8e..b47491a 100644 --- a/src/mctpd.c +++ b/src/mctpd.c @@ -57,6 +57,7 @@ static size_t MAX_PEER_SIZE = 1000000; static const uint8_t RQDI_REQ = 1<<7; static const uint8_t RQDI_RESP = 0x0; +static const uint8_t RQDI_IID_MASK = 0x1f; struct dest_phys { int ifindex; @@ -155,6 +156,9 @@ struct ctx { // Timeout in usecs for a MCTP response uint64_t mctp_timeout; + // Next IID to use + uint8_t iid; + uint8_t uuid[16]; // Verbose logging @@ -282,6 +286,19 @@ static const char* peer_tostr(const peer *peer) return dfree(str); } +static const char* peer_tostr_short(const peer *peer) +{ + size_t l = 30; + char *str = NULL; + + str = malloc(l); + if (!str) { + return "Out of memory"; + } + snprintf(str, l, "%d:%d", peer->net, peer->eid); + return dfree(str); +} + static int defer_free_handler(sd_event_source *s, void *userdata) { free(userdata); @@ -933,6 +950,121 @@ static int listen_monitor(ctx *ctx) return rc; } +static uint8_t mctp_next_iid(ctx *ctx) +{ + uint8_t iid = ctx->iid; + + ctx->iid = (iid + 1) & RQDI_IID_MASK; + return iid; +} + +static const char *command_str(uint8_t cmd) +{ + static char unknown_cmd_str[32]; + + switch (cmd) { + case MCTP_CTRL_CMD_SET_ENDPOINT_ID: + return "Set Endpoint ID"; + case MCTP_CTRL_CMD_GET_ENDPOINT_ID: + return "Get Endpoint ID"; + case MCTP_CTRL_CMD_GET_ENDPOINT_UUID: + return "Get Endpoint UUID"; + case MCTP_CTRL_CMD_GET_VERSION_SUPPORT: + return "Get Version Support"; + case MCTP_CTRL_CMD_GET_MESSAGE_TYPE_SUPPORT: + return "Get Message Type Support"; + case MCTP_CTRL_CMD_GET_VENDOR_MESSAGE_SUPPORT: + return "Get Vendor Message Support"; + case MCTP_CTRL_CMD_RESOLVE_ENDPOINT_ID: + return "Resolve Endpoint ID"; + case MCTP_CTRL_CMD_ALLOCATE_ENDPOINT_IDS: + return "Allocate Endpoint ID "; + case MCTP_CTRL_CMD_ROUTING_INFO_UPDATE: + return "Routing Info Update"; + case MCTP_CTRL_CMD_GET_ROUTING_TABLE_ENTRIES: + return "Get Routing Table Entries"; + case MCTP_CTRL_CMD_PREPARE_ENDPOINT_DISCOVERY: + return "Prepare Endpoint Discovery"; + case MCTP_CTRL_CMD_ENDPOINT_DISCOVERY: + return "Endpoint Discovery"; + case MCTP_CTRL_CMD_DISCOVERY_NOTIFY: + return "Discovery Notify"; + case MCTP_CTRL_CMD_GET_NETWORK_ID: + return "Get Network ID"; + case MCTP_CTRL_CMD_QUERY_HOP: + return "Query Hop"; + case MCTP_CTRL_CMD_RESOLVE_UUID: + return "Resolve UUID"; + case MCTP_CTRL_CMD_QUERY_RATE_LIMIT: + return "Query Rate Limit"; + case MCTP_CTRL_CMD_REQUEST_TX_RATE_LIMIT: + return "Request TX Rate Limit"; + case MCTP_CTRL_CMD_UPDATE_RATE_LIMIT: + return "Update Rate Limit"; + case MCTP_CTRL_CMD_QUERY_SUPPORTED_INTERFACES: + return "Query Supported Interfaces"; + } + + sprintf(unknown_cmd_str, "Unknown command [0x%02x]", cmd); + + return unknown_cmd_str; +} + +static const char *peer_cmd_prefix(const char *peer, uint8_t cmd) +{ + static char pfx_str[64]; + + snprintf(pfx_str, sizeof(pfx_str), "[peer %s, cmd %s]", + peer, command_str(cmd)); + + return pfx_str; +} + +/* Common checks for responses: that we have enough data for a response, + * the expected IID and opcode, and that the response indicated success. + */ +static int mctp_ctrl_validate_response(uint8_t *buf, size_t rsp_size, size_t + exp_size, const char *peer, uint8_t iid, + uint8_t cmd) +{ + struct mctp_ctrl_resp *rsp; + + if (exp_size <= sizeof(*rsp)) { + warnx("invalid expected response size!"); + return -EINVAL; + } + + if (rsp_size < exp_size) { + warnx("%s: Wrong reply length (%zu bytes)", + peer_cmd_prefix(peer, cmd), rsp_size); + return -ENOMSG; + } + + /* we have enough for the smallest common response message */ + rsp = (void *)buf; + + if ((rsp->ctrl_hdr.rq_dgram_inst & RQDI_IID_MASK) != iid) { + warnx("%s: Wrong IID (0x%02x, expected 0x%02x)", + peer_cmd_prefix(peer, cmd), + rsp->ctrl_hdr.rq_dgram_inst & RQDI_IID_MASK, iid); + return -ENOMSG; + } + + if (rsp->ctrl_hdr.command_code != cmd) { + warnx("%s: Wrong opcode (0x%02x) in response", + peer_cmd_prefix(peer, cmd), rsp->ctrl_hdr.command_code); + return -ENOMSG; + } + + if (rsp->completion_code) { + warnx("%s: Command failed, completion code 0x%02x", + peer_cmd_prefix(peer, cmd), rsp->completion_code); + return -ECONNREFUSED; + } + + return 0; +} + /* Use endpoint_query_peer() or endpoint_query_phys() instead. * * resp buffer is allocated, caller to free. @@ -1094,12 +1226,13 @@ static int endpoint_send_set_endpoint_id(const peer *peer, mctp_eid_t *new_eid) int rc; uint8_t* buf = NULL; size_t buf_size; - uint8_t stat, alloc; + uint8_t iid, stat, alloc; const dest_phys *dest = &peer->phys; rc = -1; - req.ctrl_hdr.rq_dgram_inst = RQDI_REQ; + iid = mctp_next_iid(peer->ctx); + req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid; req.ctrl_hdr.command_code = MCTP_CTRL_CMD_SET_ENDPOINT_ID; req.operation = 0; // 00b Set EID. TODO: do we want Force? req.eid = peer->eid; @@ -1108,21 +1241,13 @@ static int endpoint_send_set_endpoint_id(const peer *peer, mctp_eid_t *new_eid) if (rc < 0) goto out; - if (buf_size != sizeof(*resp)) { - warnx("%s: wrong reply length %zu bytes. dest %s", __func__, - buf_size, dest_phys_tostr(dest)); - rc = -ENOMSG; + rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp), + dest_phys_tostr(dest), + iid, MCTP_CTRL_CMD_SET_ENDPOINT_ID); + if (rc) goto out; - } - resp = (void*)buf; - if (resp->completion_code != 0) { - // TODO: make this a debug message? - warnx("Failure completion code 0x%02x from %s", - resp->completion_code, dest_phys_tostr(dest)); - rc = -ECONNREFUSED; - goto out; - } + resp = (void*)buf; stat = resp->status >> 4 & 0x3; if (stat == 0x01) { @@ -1479,29 +1604,25 @@ static int query_get_endpoint_id(ctx *ctx, const dest_phys *dest, struct mctp_ctrl_resp_get_eid *resp = NULL; uint8_t *buf = NULL; size_t buf_size; + uint8_t iid; int rc; - req.ctrl_hdr.rq_dgram_inst = RQDI_REQ; + iid = mctp_next_iid(ctx); + + req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid; req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_ENDPOINT_ID; rc = endpoint_query_phys(ctx, dest, MCTP_CTRL_HDR_MSG_TYPE, &req, sizeof(req), &buf, &buf_size, &addr); if (rc < 0) goto out; - if (buf_size != sizeof(*resp)) { - warnx("%s: wrong reply length %zu bytes. dest %s", __func__, buf_size, - dest_phys_tostr(dest)); - rc = -ENOMSG; + rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp), + dest_phys_tostr(dest), + iid, MCTP_CTRL_CMD_GET_ENDPOINT_ID); + if (rc) goto out; - } - resp = (void*)buf; - if (resp->completion_code != 0) { - warnx("Failure completion code 0x%02x from %s", - resp->completion_code, dest_phys_tostr(dest)); - rc = -ECONNREFUSED; - goto out; - } + resp = (void *)buf; *ret_eid = resp->eid; *ret_ep_type = resp->eid_type; @@ -1526,7 +1647,7 @@ static int get_endpoint_peer(ctx *ctx, sd_bus_error *berr, *ret_peer = NULL; rc = query_get_endpoint_id(ctx, dest, &eid, &ep_type, &medium_spec); - if (rc < 0) + if (rc) return rc; if (ret_cur_eid) @@ -1582,13 +1703,15 @@ static int query_get_peer_msgtypes(peer *peer) { struct mctp_ctrl_resp_get_msg_type_support *resp = NULL; uint8_t* buf = NULL; size_t buf_size, expect_size; + uint8_t iid; int rc; peer->num_message_types = 0; free(peer->message_types); peer->message_types = NULL; + iid = mctp_next_iid(peer->ctx); - req.ctrl_hdr.rq_dgram_inst = RQDI_REQ; + req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid; req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_MESSAGE_TYPE_SUPPORT; rc = endpoint_query_peer(peer, MCTP_CTRL_HDR_MSG_TYPE, @@ -1596,12 +1719,12 @@ static int query_get_peer_msgtypes(peer *peer) { if (rc < 0) goto out; - if (buf_size < sizeof(*resp)) { - warnx("%s: short reply %zu bytes. dest %s", __func__, buf_size, - peer_tostr(peer)); - rc = -ENOMSG; + rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp), + peer_tostr_short(peer), iid, + MCTP_CTRL_CMD_GET_MESSAGE_TYPE_SUPPORT); + if (rc) goto out; - } + resp = (void*)buf; expect_size = sizeof(*resp) + resp->msg_type_count; if (buf_size != expect_size) { @@ -1612,12 +1735,6 @@ static int query_get_peer_msgtypes(peer *peer) { goto out; } - if (resp->completion_code != 0x00) { - rc = -ECONNREFUSED; - goto out; - } - - peer->num_message_types = resp->msg_type_count; peer->message_types = malloc(resp->msg_type_count); if (!peer->message_types) { rc = -ENOMEM; @@ -1649,9 +1766,11 @@ query_get_peer_uuid_by_phys(ctx *ctx, const dest_phys *dest, uint8_t uuid[16]) struct mctp_ctrl_resp_get_uuid *resp = NULL; uint8_t* buf = NULL; size_t buf_size; + uint8_t iid; int rc; - req.ctrl_hdr.rq_dgram_inst = RQDI_REQ; + iid = mctp_next_iid(ctx); + req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid; req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_ENDPOINT_UUID; rc = endpoint_query_phys(ctx, dest, MCTP_CTRL_HDR_MSG_TYPE, @@ -1659,21 +1778,13 @@ query_get_peer_uuid_by_phys(ctx *ctx, const dest_phys *dest, uint8_t uuid[16]) if (rc < 0) goto out; - if (buf_size != sizeof(*resp)) { - warnx("%s: wrong reply %zu bytes. dest %s", __func__, buf_size, - dest_phys_tostr(dest)); - rc = -ENOMSG; + rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp), + dest_phys_tostr(dest), + iid, MCTP_CTRL_CMD_GET_ENDPOINT_UUID); + if (rc) goto out; - } - resp = (void*)buf; - - if (resp->completion_code != 0x00) { - warnx("Failure completion code 0x%02x from %s", - resp->completion_code, dest_phys_tostr(dest)); - rc = -ECONNREFUSED; - goto out; - } + resp = (void*)buf; memcpy(uuid, resp->uuid, 16); out: @@ -1687,6 +1798,7 @@ static int query_get_peer_uuid(peer *peer) { struct mctp_ctrl_resp_get_uuid *resp = NULL; uint8_t* buf = NULL; size_t buf_size; + uint8_t iid; int rc; if (peer->state != REMOTE) { @@ -1694,7 +1806,8 @@ static int query_get_peer_uuid(peer *peer) { return -EPROTO; } - req.ctrl_hdr.rq_dgram_inst = RQDI_REQ; + iid = mctp_next_iid(peer->ctx); + req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid; req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_ENDPOINT_UUID; rc = endpoint_query_peer(peer, MCTP_CTRL_HDR_MSG_TYPE, @@ -1702,18 +1815,13 @@ static int query_get_peer_uuid(peer *peer) { if (rc < 0) goto out; - if (buf_size != sizeof(*resp)) { - warnx("%s: wrong reply %zu bytes. dest %s", __func__, buf_size, - peer_tostr(peer)); - rc = -ENOMSG; + rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp), + peer_tostr_short(peer), + iid, MCTP_CTRL_CMD_GET_ENDPOINT_UUID); + if (rc) goto out; - } - resp = (void*)buf; - if (resp->completion_code != 0x00) { - rc = -ECONNREFUSED; - goto out; - } + resp = (void*)buf; rc = peer_set_uuid(peer, resp->uuid); if (rc < 0) diff --git a/tests/conftest.py b/tests/conftest.py index 80ea65d..6d8332c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -236,6 +236,7 @@ async def handle_mctp_message(self, sock, addr, data): else: print(f"unknown MCTP message type {a.type}") + async def handle_mctp_control(self, sock, addr, data): flags, opcode = data[0:2] rq = flags & 0x80 @@ -250,7 +251,9 @@ async def handle_mctp_control(self, sock, addr, data): else: raddr = MCTPSockAddr.for_ep_resp(self, addr, sock.addr_ext) - hdr = [0x00, opcode] + # Use IID from request, zero Rq and D bits + hdr = [iid, opcode] + if opcode == 1: # Set Endpoint ID (op, eid) = data[2:] diff --git a/tests/test_mctpd.py b/tests/test_mctpd.py index ae27938..1ff49f3 100644 --- a/tests/test_mctpd.py +++ b/tests/test_mctpd.py @@ -4,7 +4,7 @@ import asyncdbus from mctp_test_utils import mctpd_mctp_obj, mctpd_mctp_endpoint_obj -from conftest import Endpoint +from conftest import Endpoint, MCTPSockAddr # DBus constant symbol suffixes: # @@ -390,3 +390,47 @@ async def test_get_endpoint_id(dbus, mctpd): assert rsp[2] == 0x00 # EID matches the system assert rsp[3] == mctpd.system.addresses[0].eid + +""" During a LearnEndpoint's Get Endpoint ID exchange, return a response +from a different command; in this case Get Message Type Support, which happens +to be the same length as a the expected Get Endpoint ID response.""" +async def test_learn_endpoint_invalid_response_command(dbus, mctpd): + class BusyEndpoint(Endpoint): + async def handle_mctp_control(self, sock, src_addr, msg): + flags, opcode = msg[0:2] + if opcode != 2: + return await super().handle_mctp_control(sock, src_addr, msg) + dst_addr = MCTPSockAddr.for_ep_resp(self, src_addr, sock.addr_ext) + msg = bytes([flags & 0x1f, 0x05, 0x00, 0x02, 0x00, 0x01]) + await sock.send(dst_addr, msg) + + iface = mctpd.system.interfaces[0] + ep = BusyEndpoint(iface, bytes([0x1e]), eid = 15) + mctpd.network.add_endpoint(ep) + mctp = await mctpd_mctp_obj(dbus) + + with pytest.raises(asyncdbus.errors.DBusError) as ex: + rc = await mctp.call_learn_endpoint(iface.name, ep.lladdr) + + assert str(ex.value) == "Request failed" + +""" Ensure a response with an invalid IID is discarded """ +async def test_learn_endpoint_invalid_response_iid(dbus, mctpd): + class InvalidIIDEndpoint(Endpoint): + async def handle_mctp_control(self, sock, src_addr, msg): + # bump IID + flags = msg[0] + iid_mask = 0x1d + flags = (flags & ~iid_mask) | ((flags + 1) & iid_mask) + msg = bytes([flags]) + msg[1:] + return await super().handle_mctp_control(sock, src_addr, msg) + + iface = mctpd.system.interfaces[0] + ep = InvalidIIDEndpoint(iface, bytes([0x1e]), eid = 15) + mctpd.network.add_endpoint(ep) + mctp = await mctpd_mctp_obj(dbus) + + with pytest.raises(asyncdbus.errors.DBusError) as ex: + await mctp.call_learn_endpoint(iface.name, ep.lladdr) + + assert str(ex.value) == "Request failed"