From 473126a8be3514795b22a5845188fcac5d3e14b0 Mon Sep 17 00:00:00 2001
From: Jeremy Kerr <jk@codeconstruct.com.au>
Date: Mon, 5 Feb 2024 14:00:24 +0800
Subject: [PATCH] mctpd: check IIDs and opcodes in control protocol responses

Currently we ignore the IID and opcode fields in control protocol
response messages. This means that if we get a delayed response from an
endpoint, it may be interpreted incorrectly as a response to the next
message.

Instead, allocate new IIDs for each message, then check the IIDs in
responses. We do this through a new common response validation function,
mctp_ctrl_validate_response.

This requires some additional string building for error messages.

Signed-off-by: Jeremy Kerr <jk@codeconstruct.com.au>
---
 CHANGELOG.md            |   3 +
 src/mctp-control-spec.h |   5 +
 src/mctpd.c             | 240 +++++++++++++++++++++++++++++-----------
 tests/conftest.py       |   5 +-
 tests/test_mctpd.py     |  46 +++++++-
 5 files changed, 231 insertions(+), 68 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 077edf8..bda8fbb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 3. The `tests` option has changed type from `feature` to `boolean`. Tests are
    enabled by default.
 
+4. We now enforce IID and command code checks on MCTP control protocol
+   responses; this prevents odd behaviour from delayed or invalid responses.
+
 ### Fixed
 
 1. mctpd: EID assignments now work in the case where a new endpoint has a
diff --git a/src/mctp-control-spec.h b/src/mctp-control-spec.h
index ea6eb4f..ba29217 100644
--- a/src/mctp-control-spec.h
+++ b/src/mctp-control-spec.h
@@ -13,6 +13,11 @@ struct mctp_ctrl_msg_hdr {
 	uint8_t command_code;
 } __attribute__((__packed__));
 
+struct mctp_ctrl_resp { /* leading fields shared by control responses */
+	struct mctp_ctrl_msg_hdr ctrl_hdr;
+	uint8_t completion_code; /* nonzero: the command failed */
+} __attribute__((__packed__));
+
 typedef enum {
 	set_eid,
 	force_eid,
diff --git a/src/mctpd.c b/src/mctpd.c
index 0437a8e..b47491a 100644
--- a/src/mctpd.c
+++ b/src/mctpd.c
@@ -57,6 +57,7 @@ static size_t MAX_PEER_SIZE = 1000000;
 
 static const uint8_t RQDI_REQ = 1<<7;
 static const uint8_t RQDI_RESP = 0x0;
+static const uint8_t RQDI_IID_MASK = 0x1f;
 
 struct dest_phys {
 	int ifindex;
@@ -155,6 +156,9 @@ struct ctx {
 	// Timeout in usecs for a MCTP response
 	uint64_t mctp_timeout;
 
+	// Next IID to use
+	uint8_t iid;
+
 	uint8_t uuid[16];
 
 	// Verbose logging
@@ -282,6 +286,19 @@ static const char* peer_tostr(const peer *peer)
 	return dfree(str);
 }
 
+/* Render peer as "net:eid"; result is routed through dfree() like peer_tostr() */
+static const char* peer_tostr_short(const peer *peer)
+{
+	const size_t buflen = 30;
+	char *buf = malloc(buflen);
+
+	if (!buf)
+		return "Out of memory";
+
+	snprintf(buf, buflen, "%d:%d", peer->net, peer->eid);
+	return dfree(buf);
+}
+
 static int defer_free_handler(sd_event_source *s, void *userdata)
 {
 	free(userdata);
@@ -933,6 +950,121 @@ static int listen_monitor(ctx *ctx)
 	return rc;
 }
 
+/* Hand out the current IID, then advance ctx->iid within RQDI_IID_MASK. */
+static uint8_t mctp_next_iid(ctx *ctx)
+{
+	const uint8_t cur = ctx->iid;
+	ctx->iid = (cur + 1) & RQDI_IID_MASK;
+	return cur;
+}
+
+/* Name of a control command code; unknown codes use a reused static buffer. */
+static const char *command_str(uint8_t cmd)
+{
+	static char unknown_cmd_str[32];
+
+	switch (cmd) {
+	case MCTP_CTRL_CMD_SET_ENDPOINT_ID:
+		return "Set Endpoint ID";
+	case MCTP_CTRL_CMD_GET_ENDPOINT_ID:
+		return "Get Endpoint ID";
+	case MCTP_CTRL_CMD_GET_ENDPOINT_UUID:
+		return "Get Endpoint UUID";
+	case MCTP_CTRL_CMD_GET_VERSION_SUPPORT:
+		return "Get Version Support";
+	case MCTP_CTRL_CMD_GET_MESSAGE_TYPE_SUPPORT:
+		return "Get Message Type Support";
+	case MCTP_CTRL_CMD_GET_VENDOR_MESSAGE_SUPPORT:
+		return "Get Vendor Message Support";
+	case MCTP_CTRL_CMD_RESOLVE_ENDPOINT_ID:
+		return "Resolve Endpoint ID";
+	case MCTP_CTRL_CMD_ALLOCATE_ENDPOINT_IDS:
+		return "Allocate Endpoint IDs";
+	case MCTP_CTRL_CMD_ROUTING_INFO_UPDATE:
+		return "Routing Info Update";
+	case MCTP_CTRL_CMD_GET_ROUTING_TABLE_ENTRIES:
+		return "Get Routing Table Entries";
+	case MCTP_CTRL_CMD_PREPARE_ENDPOINT_DISCOVERY:
+		return "Prepare Endpoint Discovery";
+	case MCTP_CTRL_CMD_ENDPOINT_DISCOVERY:
+		return "Endpoint Discovery";
+	case MCTP_CTRL_CMD_DISCOVERY_NOTIFY:
+		return "Discovery Notify";
+	case MCTP_CTRL_CMD_GET_NETWORK_ID:
+		return "Get Network ID";
+	case MCTP_CTRL_CMD_QUERY_HOP:
+		return "Query Hop";
+	case MCTP_CTRL_CMD_RESOLVE_UUID:
+		return "Resolve UUID";
+	case MCTP_CTRL_CMD_QUERY_RATE_LIMIT:
+		return "Query Rate Limit";
+	case MCTP_CTRL_CMD_REQUEST_TX_RATE_LIMIT:
+		return "Request TX Rate Limit";
+	case MCTP_CTRL_CMD_UPDATE_RATE_LIMIT:
+		return "Update Rate Limit";
+	case MCTP_CTRL_CMD_QUERY_SUPPORTED_INTERFACES:
+		return "Query Supported Interfaces";
+	}
+	snprintf(unknown_cmd_str, sizeof(unknown_cmd_str),
+		 "Unknown command [0x%02x]", cmd);
+	return unknown_cmd_str;
+}
+
+/* Build a "[peer P, cmd C]" log prefix in a static buffer (overwritten by the
+ * next call); sized so that a longer peer description is not cut short. */
+static const char *peer_cmd_prefix(const char *peer, uint8_t cmd)
+{
+	static char pfx_str[128];
+	snprintf(pfx_str, sizeof(pfx_str), "[peer %s, cmd %s]",
+		 peer, command_str(cmd));
+	return pfx_str;
+}
+
+/* Common checks for responses: that we have enough data for a response,
+ * the expected IID and opcode, and that the response indicated success.
+ * exp_size is the minimum reply length; it may equal the bare response.
+ * Returns 0, or -EINVAL (bad exp_size), -ENOMSG (short reply, wrong IID
+ * or opcode), -ECONNREFUSED (nonzero completion code).
+ */
+static int mctp_ctrl_validate_response(uint8_t *buf, size_t rsp_size,
+				       size_t exp_size, const char *peer,
+				       uint8_t iid, uint8_t cmd)
+{
+	const struct mctp_ctrl_resp *rsp = (const void *)buf;
+
+	if (exp_size < sizeof(*rsp)) {
+		warnx("invalid expected response size!");
+		return -EINVAL;
+	}
+
+	if (rsp_size < exp_size) {
+		warnx("%s: Wrong reply length (%zu bytes)",
+		      peer_cmd_prefix(peer, cmd), rsp_size);
+		return -ENOMSG;
+	}
+
+	if ((rsp->ctrl_hdr.rq_dgram_inst & RQDI_IID_MASK) != iid) {
+		warnx("%s: Wrong IID (0x%02x, expected 0x%02x)",
+		      peer_cmd_prefix(peer, cmd),
+		      rsp->ctrl_hdr.rq_dgram_inst & RQDI_IID_MASK, iid);
+		return -ENOMSG;
+	}
+
+	if (rsp->ctrl_hdr.command_code != cmd) {
+		warnx("%s: Wrong opcode (0x%02x) in response",
+		      peer_cmd_prefix(peer, cmd), rsp->ctrl_hdr.command_code);
+		return -ENOMSG;
+	}
+
+	if (rsp->completion_code) {
+		warnx("%s: Command failed, completion code 0x%02x",
+		      peer_cmd_prefix(peer, cmd), rsp->completion_code);
+		return -ECONNREFUSED;
+	}
+
+	return 0;
+}
+
 /* Use endpoint_query_peer() or endpoint_query_phys() instead.
  *
  * resp buffer is allocated, caller to free.
@@ -1094,12 +1226,13 @@ static int endpoint_send_set_endpoint_id(const peer *peer, mctp_eid_t *new_eid)
 	int rc;
 	uint8_t* buf = NULL;
 	size_t buf_size;
-	uint8_t stat, alloc;
+	uint8_t iid, stat, alloc;
 	const dest_phys *dest = &peer->phys;
 
 	rc = -1;
 
-	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ;
+	iid = mctp_next_iid(peer->ctx);
+	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid;
 	req.ctrl_hdr.command_code = MCTP_CTRL_CMD_SET_ENDPOINT_ID;
 	req.operation = 0; // 00b Set EID. TODO: do we want Force?
 	req.eid = peer->eid;
@@ -1108,21 +1241,13 @@ static int endpoint_send_set_endpoint_id(const peer *peer, mctp_eid_t *new_eid)
 	if (rc < 0)
 		goto out;
 
-	if (buf_size != sizeof(*resp)) {
-		warnx("%s: wrong reply length %zu bytes. dest %s", __func__,
-			buf_size, dest_phys_tostr(dest));
-		rc = -ENOMSG;
+	rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp),
+					 dest_phys_tostr(dest),
+					 iid, MCTP_CTRL_CMD_SET_ENDPOINT_ID);
+	if (rc)
 		goto out;
-	}
-	resp = (void*)buf;
 
-	if (resp->completion_code != 0) {
-		// TODO: make this a debug message?
-		warnx("Failure completion code 0x%02x from %s",
-			resp->completion_code, dest_phys_tostr(dest));
-		rc = -ECONNREFUSED;
-		goto out;
-	}
+	resp = (void*)buf;
 
 	stat = resp->status >> 4 & 0x3;
 	if (stat == 0x01) {
@@ -1479,29 +1604,25 @@ static int query_get_endpoint_id(ctx *ctx, const dest_phys *dest,
 	struct mctp_ctrl_resp_get_eid *resp = NULL;
 	uint8_t *buf = NULL;
 	size_t buf_size;
+	uint8_t iid;
 	int rc;
 
-	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ;
+	iid = mctp_next_iid(ctx);
+
+	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid;
 	req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_ENDPOINT_ID;
 	rc = endpoint_query_phys(ctx, dest, MCTP_CTRL_HDR_MSG_TYPE, &req,
 		sizeof(req), &buf, &buf_size, &addr);
 	if (rc < 0)
 		goto out;
 
-	if (buf_size != sizeof(*resp)) {
-		warnx("%s: wrong reply length %zu bytes. dest %s", __func__, buf_size,
-			dest_phys_tostr(dest));
-		rc = -ENOMSG;
+	rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp),
+					 dest_phys_tostr(dest),
+					 iid, MCTP_CTRL_CMD_GET_ENDPOINT_ID);
+	if (rc)
 		goto out;
-	}
-	resp = (void*)buf;
 
-	if (resp->completion_code != 0) {
-		warnx("Failure completion code 0x%02x from %s",
-			resp->completion_code, dest_phys_tostr(dest));
-		rc = -ECONNREFUSED;
-		goto out;
-	}
+	resp = (void *)buf;
 
 	*ret_eid = resp->eid;
 	*ret_ep_type = resp->eid_type;
@@ -1526,7 +1647,7 @@ static int get_endpoint_peer(ctx *ctx, sd_bus_error *berr,
 
 	*ret_peer = NULL;
 	rc = query_get_endpoint_id(ctx, dest, &eid, &ep_type, &medium_spec);
-	if (rc < 0)
+	if (rc)
 		return rc;
 
 	if (ret_cur_eid)
@@ -1582,13 +1703,15 @@ static int query_get_peer_msgtypes(peer *peer) {
 	struct mctp_ctrl_resp_get_msg_type_support *resp = NULL;
 	uint8_t* buf = NULL;
 	size_t buf_size, expect_size;
+	uint8_t iid;
 	int rc;
 
 	peer->num_message_types = 0;
 	free(peer->message_types);
 	peer->message_types = NULL;
+	iid = mctp_next_iid(peer->ctx);
 
-	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ;
+	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid;
 	req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_MESSAGE_TYPE_SUPPORT;
 
 	rc = endpoint_query_peer(peer, MCTP_CTRL_HDR_MSG_TYPE,
@@ -1596,12 +1719,12 @@ static int query_get_peer_msgtypes(peer *peer) {
 	if (rc < 0)
 		goto out;
 
-	if (buf_size < sizeof(*resp)) {
-		warnx("%s: short reply %zu bytes. dest %s", __func__, buf_size,
-			peer_tostr(peer));
-		rc = -ENOMSG;
+	rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp),
+					 peer_tostr_short(peer), iid,
+					 MCTP_CTRL_CMD_GET_MESSAGE_TYPE_SUPPORT);
+	if (rc)
 		goto out;
-	}
+
 	resp = (void*)buf;
 	expect_size = sizeof(*resp) + resp->msg_type_count;
 	if (buf_size != expect_size) {
@@ -1612,12 +1735,6 @@ static int query_get_peer_msgtypes(peer *peer) {
 		goto out;
 	}
 
-	if (resp->completion_code != 0x00) {
-		rc = -ECONNREFUSED;
-		goto out;
-	}
-
-	peer->num_message_types = resp->msg_type_count;
 	peer->message_types = malloc(resp->msg_type_count);
 	if (!peer->message_types) {
 		rc = -ENOMEM;
@@ -1649,9 +1766,11 @@ query_get_peer_uuid_by_phys(ctx *ctx, const dest_phys *dest, uint8_t uuid[16])
 	struct mctp_ctrl_resp_get_uuid *resp = NULL;
 	uint8_t* buf = NULL;
 	size_t buf_size;
+	uint8_t iid;
 	int rc;
 
-	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ;
+	iid = mctp_next_iid(ctx);
+	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid;
 	req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_ENDPOINT_UUID;
 
 	rc = endpoint_query_phys(ctx, dest, MCTP_CTRL_HDR_MSG_TYPE,
@@ -1659,21 +1778,13 @@ query_get_peer_uuid_by_phys(ctx *ctx, const dest_phys *dest, uint8_t uuid[16])
 	if (rc < 0)
 		goto out;
 
-	if (buf_size != sizeof(*resp)) {
-		warnx("%s: wrong reply %zu bytes. dest %s", __func__, buf_size,
-			dest_phys_tostr(dest));
-		rc = -ENOMSG;
+	rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp),
+					 dest_phys_tostr(dest),
+					 iid, MCTP_CTRL_CMD_GET_ENDPOINT_UUID);
+	if (rc)
 		goto out;
-	}
-	resp = (void*)buf;
-
-	if (resp->completion_code != 0x00) {
-		warnx("Failure completion code 0x%02x from %s",
-			resp->completion_code, dest_phys_tostr(dest));
-		rc = -ECONNREFUSED;
-		goto out;
-	}
 
+	resp = (void*)buf;
 	memcpy(uuid, resp->uuid, 16);
 
 out:
@@ -1687,6 +1798,7 @@ static int query_get_peer_uuid(peer *peer) {
 	struct mctp_ctrl_resp_get_uuid *resp = NULL;
 	uint8_t* buf = NULL;
 	size_t buf_size;
+	uint8_t iid;
 	int rc;
 
 	if (peer->state != REMOTE) {
@@ -1694,7 +1806,8 @@ static int query_get_peer_uuid(peer *peer) {
 		return -EPROTO;
 	}
 
-	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ;
+	iid = mctp_next_iid(peer->ctx);
+	req.ctrl_hdr.rq_dgram_inst = RQDI_REQ | iid;
 	req.ctrl_hdr.command_code = MCTP_CTRL_CMD_GET_ENDPOINT_UUID;
 
 	rc = endpoint_query_peer(peer, MCTP_CTRL_HDR_MSG_TYPE,
@@ -1702,18 +1815,13 @@ static int query_get_peer_uuid(peer *peer) {
 	if (rc < 0)
 		goto out;
 
-	if (buf_size != sizeof(*resp)) {
-		warnx("%s: wrong reply %zu bytes. dest %s", __func__, buf_size,
-			peer_tostr(peer));
-		rc = -ENOMSG;
+	rc = mctp_ctrl_validate_response(buf, buf_size, sizeof(*resp),
+					 peer_tostr_short(peer),
+					 iid, MCTP_CTRL_CMD_GET_ENDPOINT_UUID);
+	if (rc)
 		goto out;
-	}
-	resp = (void*)buf;
 
-	if (resp->completion_code != 0x00) {
-		rc = -ECONNREFUSED;
-		goto out;
-	}
+	resp = (void*)buf;
 
 	rc = peer_set_uuid(peer, resp->uuid);
 	if (rc < 0)
diff --git a/tests/conftest.py b/tests/conftest.py
index 80ea65d..6d8332c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -236,6 +236,7 @@ async def handle_mctp_message(self, sock, addr, data):
         else:
             print(f"unknown MCTP message type {a.type}")
 
+
     async def handle_mctp_control(self, sock, addr, data):
         flags, opcode = data[0:2]
         rq = flags & 0x80
@@ -250,7 +251,9 @@ async def handle_mctp_control(self, sock, addr, data):
         else:
 
             raddr = MCTPSockAddr.for_ep_resp(self, addr, sock.addr_ext)
-            hdr = [0x00, opcode]
+            # Use IID from request, zero Rq and D bits
+            hdr = [iid, opcode]
+
             if opcode == 1:
                 # Set Endpoint ID
                 (op, eid) = data[2:]
diff --git a/tests/test_mctpd.py b/tests/test_mctpd.py
index ae27938..1ff49f3 100644
--- a/tests/test_mctpd.py
+++ b/tests/test_mctpd.py
@@ -4,7 +4,7 @@
 import asyncdbus
 
 from mctp_test_utils import mctpd_mctp_obj, mctpd_mctp_endpoint_obj
-from conftest import Endpoint
+from conftest import Endpoint, MCTPSockAddr
 
 # DBus constant symbol suffixes:
 #
@@ -390,3 +390,47 @@ async def test_get_endpoint_id(dbus, mctpd):
     assert rsp[2] == 0x00
     # EID matches the system
     assert rsp[3] == mctpd.system.addresses[0].eid
+
+""" During a LearnEndpoint's Get Endpoint ID exchange, return a response
+from a different command; in this case Get Message Type Support, which happens
+to be the same length as the expected Get Endpoint ID response."""
+async def test_learn_endpoint_invalid_response_command(dbus, mctpd):
+    class BusyEndpoint(Endpoint):
+        async def handle_mctp_control(self, sock, src_addr, msg):
+            flags, opcode = msg[0:2]
+            if opcode != 2:
+                return await super().handle_mctp_control(sock, src_addr, msg)
+            dst_addr = MCTPSockAddr.for_ep_resp(self, src_addr, sock.addr_ext)
+            msg = bytes([flags & 0x1f, 0x05, 0x00, 0x02, 0x00, 0x01])
+            await sock.send(dst_addr, msg)
+
+    iface = mctpd.system.interfaces[0]
+    ep = BusyEndpoint(iface, bytes([0x1e]), eid = 15)
+    mctpd.network.add_endpoint(ep)
+    mctp = await mctpd_mctp_obj(dbus)
+
+    with pytest.raises(asyncdbus.errors.DBusError) as ex:
+        rc = await mctp.call_learn_endpoint(iface.name, ep.lladdr)
+
+    assert str(ex.value) == "Request failed"
+
+""" Ensure a response with an invalid IID is discarded """
+async def test_learn_endpoint_invalid_response_iid(dbus, mctpd):
+    class InvalidIIDEndpoint(Endpoint):
+        async def handle_mctp_control(self, sock, src_addr, msg):
+            # bump IID
+            flags = msg[0]
+            iid_mask = 0x1d
+            flags = (flags & ~iid_mask) | ((flags + 1) & iid_mask)
+            msg = bytes([flags]) + msg[1:]
+            return await super().handle_mctp_control(sock, src_addr, msg)
+
+    iface = mctpd.system.interfaces[0]
+    ep = InvalidIIDEndpoint(iface, bytes([0x1e]), eid = 15)
+    mctpd.network.add_endpoint(ep)
+    mctp = await mctpd_mctp_obj(dbus)
+
+    with pytest.raises(asyncdbus.errors.DBusError) as ex:
+        await mctp.call_learn_endpoint(iface.name, ep.lladdr)
+
+    assert str(ex.value) == "Request failed"