From ee9b1af5949ddceb273b3de3aa4b7f433e746943 Mon Sep 17 00:00:00 2001 From: Jeremy Kerr Date: Mon, 11 Dec 2023 10:21:07 +0800 Subject: [PATCH] mctp: introduce ops wrapper for socket operations We'll want to be able to mock the kernel interactions for testing, so call all of the socket operations (on both netlink and MCTP sockets) through a central set of funtion pointers. This will allow us to override with mock implementations for testing. Signed-off-by: Jeremy Kerr --- meson.build | 5 ++-- src/mctp-netlink.c | 45 +++++++++++++++++------------ src/mctp-ops.c | 72 ++++++++++++++++++++++++++++++++++++++++++++++ src/mctp-ops.h | 33 +++++++++++++++++++++ src/mctp.c | 3 ++ src/mctpd.c | 28 +++++++++++------- 6 files changed, 155 insertions(+), 31 deletions(-) create mode 100644 src/mctp-ops.c create mode 100644 src/mctp-ops.h diff --git a/meson.build b/meson.build index d52f2a5..8b19d25 100644 --- a/meson.build +++ b/meson.build @@ -35,9 +35,10 @@ config_h = configure_file( util_sources = ['src/mctp-util.c'] netlink_sources = ['src/mctp-netlink.c'] +ops_sources = ['src/mctp-ops.c'] executable('mctp', - sources: ['src/mctp.c'] + netlink_sources + util_sources, + sources: ['src/mctp.c'] + netlink_sources + util_sources + ops_sources, install: true, ) @@ -53,7 +54,7 @@ if libsystemd.found() executable('mctpd', sources: [ 'src/mctpd.c', - ] + netlink_sources + util_sources, + ] + netlink_sources + util_sources + ops_sources, dependencies: libsystemd, install: true, install_dir: get_option('sbindir'), diff --git a/src/mctp-netlink.c b/src/mctp-netlink.c index 69f3bee..56f620d 100644 --- a/src/mctp-netlink.c +++ b/src/mctp-netlink.c @@ -18,6 +18,7 @@ #include "mctp-netlink.h" #include "mctp.h" #include "mctp-util.h" +#include "mctp-ops.h" struct linkmap_entry { int ifindex; @@ -65,26 +66,27 @@ static int open_nl_socket(void) struct sockaddr_nl addr; int opt, rc, sd = -1; - rc = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + rc = mctp_ops.nl.socket(); if (rc < 0) goto err; sd = rc; memset(&addr, 0, sizeof(addr)); addr.nl_family = AF_NETLINK; - rc = bind(sd, (struct sockaddr *)&addr, sizeof(addr)); + rc = mctp_ops.nl.bind(sd, (struct sockaddr *)&addr, sizeof(addr)); if (rc) goto err; opt = 1; - rc = setsockopt(sd, SOL_NETLINK, NETLINK_GET_STRICT_CHK, - &opt, sizeof(opt)); + rc = mctp_ops.nl.setsockopt(sd, SOL_NETLINK, NETLINK_GET_STRICT_CHK, + &opt, sizeof(opt)); if (rc) { rc = -errno; goto err; } opt = 1; - rc = setsockopt(sd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, sizeof(opt)); + rc = mctp_ops.nl.setsockopt(sd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, + sizeof(opt)); if (rc) { rc = -errno; @@ -144,8 +146,8 @@ static void free_linkmap(struct linkmap_entry *linkmap, size_t count) void mctp_nl_close(mctp_nl *nl) { free_linkmap(nl->linkmap, nl->linkmap_count); - close(nl->sd); - close(nl->sd_monitor); + mctp_ops.nl.close(nl->sd); + mctp_ops.nl.close(nl->sd_monitor); free(nl); } @@ -164,14 +166,18 @@ int mctp_nl_monitor(mctp_nl *nl, bool enable) return nl->sd_monitor; opt = RTNLGRP_LINK; - rc = setsockopt(nl->sd_monitor, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &opt, sizeof(opt)); + rc = mctp_ops.nl.setsockopt(nl->sd_monitor, SOL_NETLINK, + NETLINK_ADD_MEMBERSHIP, + &opt, sizeof(opt)); if (rc < 0) { rc = -errno; goto err; } opt = RTNLGRP_MCTP_IFADDR; - rc = setsockopt(nl->sd_monitor, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &opt, sizeof(opt)); + rc = mctp_ops.nl.setsockopt(nl->sd_monitor, SOL_NETLINK, + NETLINK_ADD_MEMBERSHIP, + &opt, sizeof(opt)); if (rc < 0) { rc = -errno; if (errno == EINVAL) { @@ -492,7 +498,7 @@ static int handle_nlmsg_ack(mctp_nl *nl) int rc; size_t len; - rc = recvfrom(nl->sd, resp, sizeof(resp), 0, NULL, NULL); + rc = mctp_ops.nl.recvfrom(nl->sd, resp, sizeof(resp), 0, NULL, NULL); if (rc < 0) return rc; len = rc; @@ -535,8 +541,8 @@ int mctp_nl_send(mctp_nl *nl, struct nlmsghdr *msg) addr.nl_family = AF_NETLINK; addr.nl_pid = 0; - rc = sendto(nl->sd, msg, msg->nlmsg_len, 0, - (struct sockaddr *)&addr, sizeof(addr)); + rc = mctp_ops.nl.sendto(nl->sd, msg, msg->nlmsg_len, 0, + (struct sockaddr *)&addr, sizeof(addr)); if (rc < 0) return rc; @@ -585,7 +591,8 @@ int mctp_nl_recv_all(mctp_nl *nl, int sd, // read all the responses into a single buffer while (!done) { - rc = recvfrom(sd, NULL, 0, MSG_PEEK|MSG_TRUNC, NULL, 0); + rc = mctp_ops.nl.recvfrom(sd, NULL, 0, MSG_PEEK|MSG_TRUNC, + NULL, 0); if (rc < 0) { warnx("recvfrom(MSG_PEEK)"); rc = -errno; @@ -614,8 +621,8 @@ int mctp_nl_recv_all(mctp_nl *nl, int sd, resp = respbuf + pos; addrlen = sizeof(addr); - rc = recvfrom(sd, resp, readlen, MSG_TRUNC, - (struct sockaddr *)&addr, &addrlen); + rc = mctp_ops.nl.recvfrom(sd, resp, readlen, MSG_TRUNC, + (struct sockaddr *)&addr, &addrlen); if (rc < 0) { warnx("recvfrom(MSG_PEEK)"); rc = -errno; @@ -747,8 +754,8 @@ static int fill_linkmap(mctp_nl *nl) addrlen = sizeof(addr); for (;;) { - rc = recvfrom(nl->sd, NULL, 0, MSG_TRUNC | MSG_PEEK, - NULL, NULL); + rc = mctp_ops.nl.recvfrom(nl->sd, NULL, 0, MSG_TRUNC | MSG_PEEK, + NULL, NULL); if (rc < 0) { warn("recvfrom(MSG_PEEK)"); break; @@ -768,8 +775,8 @@ static int fill_linkmap(mctp_nl *nl) buf = tmp; } - rc = recvfrom(nl->sd, buf, buflen, 0, - (struct sockaddr *)&addr, &addrlen); + rc = mctp_ops.nl.recvfrom(nl->sd, buf, buflen, 0, + (struct sockaddr *)&addr, &addrlen); if (rc < 0) { warn("recvfrom()"); break; diff --git a/src/mctp-ops.c b/src/mctp-ops.c new file mode 100644 index 0000000..bfb491d --- /dev/null +++ b/src/mctp-ops.c @@ -0,0 +1,72 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * mctp-ops: Abstraction for socket operations for mctp & mctpd. + * + * Copyright (c) 2023 Code Construct + */ + +#define _GNU_SOURCE + +#include +#include + +#include "mctp-ops.h" + +static int mctp_op_mctp_socket(void) +{ + return socket(AF_MCTP, SOCK_DGRAM, 0); +} + +static int mctp_op_netlink_socket(void) +{ + return socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); +} + +static int mctp_op_bind(int sd, struct sockaddr *addr, socklen_t addrlen) +{ + return bind(sd, addr, addrlen); +} + +static int mctp_op_setsockopt(int sd, int level, int optname, void *optval, + socklen_t optlen) +{ + return setsockopt(sd, level, optname, optval, optlen); +} + +static ssize_t mctp_op_sendto(int sd, const void *buf, size_t len, int flags, + const struct sockaddr *dest, socklen_t addrlen) +{ + return sendto(sd, buf, len, flags, dest, addrlen); +} + +static ssize_t mctp_op_recvfrom(int sd, void *buf, size_t len, int flags, + struct sockaddr *src, socklen_t *addrlen) +{ + return recvfrom(sd, buf, len, flags, src, addrlen); +} + +static int mctp_op_close(int sd) +{ + return close(sd); +} + +struct mctp_ops mctp_ops = { + .mctp = { + .socket = mctp_op_mctp_socket, + .setsockopt = mctp_op_setsockopt, + .bind = mctp_op_bind, + .sendto = mctp_op_sendto, + .recvfrom = mctp_op_recvfrom, + .close = mctp_op_close, + }, + .nl = { + .socket = mctp_op_netlink_socket, + .setsockopt = mctp_op_setsockopt, + .bind = mctp_op_bind, + .sendto = mctp_op_sendto, + .recvfrom = mctp_op_recvfrom, + .close = mctp_op_close, + }, +}; + +void mctp_ops_init(void) { } diff --git a/src/mctp-ops.h b/src/mctp-ops.h new file mode 100644 index 0000000..c5d4636 --- /dev/null +++ b/src/mctp-ops.h @@ -0,0 +1,33 @@ + +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * mctpd: bus owner for MCTP using Linux kernel + * + * Copyright (c) 2023 Code Construct + */ +#pragma once + +#include + +#define _GNU_SOURCE + +struct socket_ops { + int (*socket)(void); + int (*setsockopt)(int sd, int level, int optname, void *optval, + socklen_t optlen); + int (*bind)(int sd, struct sockaddr *addr, socklen_t addrlen); + ssize_t (*sendto)(int sd, const void *buf, size_t len, int flags, + const struct sockaddr *dest, socklen_t addrlen); + ssize_t (*recvfrom)(int sd, void *buf, size_t len, int flags, + struct sockaddr *src, socklen_t *addrlen); + int (*close)(int sd); +}; + +struct mctp_ops { + struct socket_ops mctp; + struct socket_ops nl; +}; + +extern struct mctp_ops mctp_ops; + +void mctp_ops_init(void); diff --git a/src/mctp.c b/src/mctp.c index 72ab7db..0dd13b3 100644 --- a/src/mctp.c +++ b/src/mctp.c @@ -33,6 +33,7 @@ #include "mctp.h" #include "mctp-util.h" #include "mctp-netlink.h" +#include "mctp-ops.h" #define max(a, b) ((a) > (b) ? (a) : (b)) #define min(a, b) ((a) < (b) ? (a) : (b)) @@ -769,6 +770,8 @@ static int cmd_addr_addremove(struct ctx *ctx, return -1; } + mctp_ops_init(); + eidstr = argv[1]; linkstr = argv[3]; diff --git a/src/mctpd.c b/src/mctpd.c index 000b4fd..9b30e1c 100644 --- a/src/mctpd.c +++ b/src/mctpd.c @@ -32,6 +32,7 @@ #include "mctp-util.h" #include "mctp-netlink.h" #include "mctp-control-spec.h" +#include "mctp-ops.h" #define max(a, b) ((a) > (b) ? (a) : (b)) #define min(a, b) ((a) < (b) ? (a) : (b)) @@ -417,7 +418,8 @@ static int read_message(ctx *ctx, int sd, uint8_t **ret_buf, size_t *ret_buf_siz uint8_t* buf = NULL; size_t buf_size; - len = recvfrom(sd, NULL, 0, MSG_PEEK | MSG_TRUNC, NULL, 0); + len = mctp_ops.mctp.recvfrom(sd, NULL, 0, MSG_PEEK | MSG_TRUNC, + NULL, 0); if (len < 0) { rc = -errno; goto out; @@ -432,7 +434,8 @@ static int read_message(ctx *ctx, int sd, uint8_t **ret_buf, size_t *ret_buf_siz addrlen = sizeof(struct sockaddr_mctp_ext); memset(ret_addr, 0x0, addrlen); - len = recvfrom(sd, buf, buf_size, MSG_TRUNC, (struct sockaddr *)ret_addr, + len = mctp_ops.mctp.recvfrom(sd, buf, buf_size, MSG_TRUNC, + (struct sockaddr *)ret_addr, &addrlen); if (len < 0) { rc = -errno; @@ -485,8 +488,9 @@ static int reply_message(ctx *ctx, int sd, const void *resp, size_t resp_len, return -EPROTO; } - len = sendto(sd, resp, resp_len, 0, - (struct sockaddr*)&reply_addr, sizeof(reply_addr)); + len = mctp_ops.mctp.sendto(sd, resp, resp_len, 0, + (struct sockaddr *)&reply_addr, + sizeof(reply_addr)); if (len < 0) { return -errno; } @@ -785,7 +789,7 @@ static int listen_control_msg(ctx *ctx, int net) struct sockaddr_mctp addr = { 0 }; int rc, sd = -1, val; - sd = socket(AF_MCTP, SOCK_DGRAM, 0); + sd = mctp_ops.mctp.socket(); if (sd < 0) { rc = -errno; warn("%s: socket() failed", __func__); @@ -798,7 +802,7 @@ static int listen_control_msg(ctx *ctx, int net) addr.smctp_type = MCTP_CTRL_HDR_MSG_TYPE; addr.smctp_tag = MCTP_TAG_OWNER; - rc = bind(sd, (struct sockaddr *)&addr, sizeof(addr)); + rc = mctp_ops.mctp.bind(sd, (struct sockaddr *)&addr, sizeof(addr)); if (rc < 0) { rc = -errno; warn("%s: bind() failed", __func__); @@ -806,7 +810,8 @@ static int listen_control_msg(ctx *ctx, int net) } val = 1; - rc = setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, &val, sizeof(val)); + rc = mctp_ops.mctp.setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, + &val, sizeof(val)); if (rc < 0) { rc = -errno; warn("Kernel does not support MCTP extended addressing"); @@ -936,7 +941,7 @@ static int endpoint_query_addr(ctx *ctx, *resp = NULL; *resp_len = 0; - sd = socket(AF_MCTP, SOCK_DGRAM, 0); + sd = mctp_ops.mctp.socket(); if (sd < 0) { warn("socket"); rc = -errno; @@ -945,7 +950,8 @@ static int endpoint_query_addr(ctx *ctx, // We want extended addressing on all received messages val = 1; - rc = setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, &val, sizeof(val)); + rc = mctp_ops.mctp.setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, + &val, sizeof(val)); if (rc < 0) { rc = -errno; warn("Kernel does not support MCTP extended addressing"); @@ -963,7 +969,8 @@ static int endpoint_query_addr(ctx *ctx, rc = -EPROTO; goto out; } - rc = sendto(sd, req, req_len, 0, (struct sockaddr*)req_addr, req_addr_len); + rc = mctp_ops.mctp.sendto(sd, req, req_len, 0, + (struct sockaddr *)req_addr, req_addr_len); if (rc < 0) { rc = -errno; if (ctx->verbose) { @@ -3477,6 +3484,7 @@ int main(int argc, char **argv) setlinebuf(stdout); setup_config(ctx); + mctp_ops_init(); rc = parse_args(ctx, argc, argv); if (rc != 0) {