diff --git a/meson.build b/meson.build index d52f2a5..8b19d25 100644 --- a/meson.build +++ b/meson.build @@ -35,9 +35,10 @@ config_h = configure_file( util_sources = ['src/mctp-util.c'] netlink_sources = ['src/mctp-netlink.c'] +ops_sources = ['src/mctp-ops.c'] executable('mctp', - sources: ['src/mctp.c'] + netlink_sources + util_sources, + sources: ['src/mctp.c'] + netlink_sources + util_sources + ops_sources, install: true, ) @@ -53,7 +54,7 @@ if libsystemd.found() executable('mctpd', sources: [ 'src/mctpd.c', - ] + netlink_sources + util_sources, + ] + netlink_sources + util_sources + ops_sources, dependencies: libsystemd, install: true, install_dir: get_option('sbindir'), diff --git a/src/mctp-netlink.c b/src/mctp-netlink.c index 69f3bee..56f620d 100644 --- a/src/mctp-netlink.c +++ b/src/mctp-netlink.c @@ -18,6 +18,7 @@ #include "mctp-netlink.h" #include "mctp.h" #include "mctp-util.h" +#include "mctp-ops.h" struct linkmap_entry { int ifindex; @@ -65,26 +66,27 @@ static int open_nl_socket(void) struct sockaddr_nl addr; int opt, rc, sd = -1; - rc = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + rc = mctp_ops.nl.socket(); if (rc < 0) goto err; sd = rc; memset(&addr, 0, sizeof(addr)); addr.nl_family = AF_NETLINK; - rc = bind(sd, (struct sockaddr *)&addr, sizeof(addr)); + rc = mctp_ops.nl.bind(sd, (struct sockaddr *)&addr, sizeof(addr)); if (rc) goto err; opt = 1; - rc = setsockopt(sd, SOL_NETLINK, NETLINK_GET_STRICT_CHK, - &opt, sizeof(opt)); + rc = mctp_ops.nl.setsockopt(sd, SOL_NETLINK, NETLINK_GET_STRICT_CHK, + &opt, sizeof(opt)); if (rc) { rc = -errno; goto err; } opt = 1; - rc = setsockopt(sd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, sizeof(opt)); + rc = mctp_ops.nl.setsockopt(sd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, + sizeof(opt)); if (rc) { rc = -errno; @@ -144,8 +146,8 @@ static void free_linkmap(struct linkmap_entry *linkmap, size_t count) void mctp_nl_close(mctp_nl *nl) { free_linkmap(nl->linkmap, nl->linkmap_count); - close(nl->sd); - close(nl->sd_monitor); + mctp_ops.nl.close(nl->sd); + mctp_ops.nl.close(nl->sd_monitor); free(nl); } @@ -164,14 +166,18 @@ int mctp_nl_monitor(mctp_nl *nl, bool enable) return nl->sd_monitor; opt = RTNLGRP_LINK; - rc = setsockopt(nl->sd_monitor, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &opt, sizeof(opt)); + rc = mctp_ops.nl.setsockopt(nl->sd_monitor, SOL_NETLINK, + NETLINK_ADD_MEMBERSHIP, + &opt, sizeof(opt)); if (rc < 0) { rc = -errno; goto err; } opt = RTNLGRP_MCTP_IFADDR; - rc = setsockopt(nl->sd_monitor, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &opt, sizeof(opt)); + rc = mctp_ops.nl.setsockopt(nl->sd_monitor, SOL_NETLINK, + NETLINK_ADD_MEMBERSHIP, + &opt, sizeof(opt)); if (rc < 0) { rc = -errno; if (errno == EINVAL) { @@ -492,7 +498,7 @@ static int handle_nlmsg_ack(mctp_nl *nl) int rc; size_t len; - rc = recvfrom(nl->sd, resp, sizeof(resp), 0, NULL, NULL); + rc = mctp_ops.nl.recvfrom(nl->sd, resp, sizeof(resp), 0, NULL, NULL); if (rc < 0) return rc; len = rc; @@ -535,8 +541,8 @@ int mctp_nl_send(mctp_nl *nl, struct nlmsghdr *msg) addr.nl_family = AF_NETLINK; addr.nl_pid = 0; - rc = sendto(nl->sd, msg, msg->nlmsg_len, 0, - (struct sockaddr *)&addr, sizeof(addr)); + rc = mctp_ops.nl.sendto(nl->sd, msg, msg->nlmsg_len, 0, + (struct sockaddr *)&addr, sizeof(addr)); if (rc < 0) return rc; @@ -585,7 +591,8 @@ int mctp_nl_recv_all(mctp_nl *nl, int sd, // read all the responses into a single buffer while (!done) { - rc = recvfrom(sd, NULL, 0, MSG_PEEK|MSG_TRUNC, NULL, 0); + rc = mctp_ops.nl.recvfrom(sd, NULL, 0, MSG_PEEK|MSG_TRUNC, + NULL, 0); if (rc < 0) { warnx("recvfrom(MSG_PEEK)"); rc = -errno; @@ -614,8 +621,8 @@ int mctp_nl_recv_all(mctp_nl *nl, int sd, resp = respbuf + pos; addrlen = sizeof(addr); - rc = recvfrom(sd, resp, readlen, MSG_TRUNC, - (struct sockaddr *)&addr, &addrlen); + rc = mctp_ops.nl.recvfrom(sd, resp, readlen, MSG_TRUNC, + (struct sockaddr *)&addr, &addrlen); if (rc < 0) { warnx("recvfrom(MSG_PEEK)"); rc = -errno; @@ -747,8 +754,8 @@ static int fill_linkmap(mctp_nl *nl) addrlen = sizeof(addr); for (;;) { - rc = recvfrom(nl->sd, NULL, 0, MSG_TRUNC | MSG_PEEK, - NULL, NULL); + rc = mctp_ops.nl.recvfrom(nl->sd, NULL, 0, MSG_TRUNC | MSG_PEEK, + NULL, NULL); if (rc < 0) { warn("recvfrom(MSG_PEEK)"); break; @@ -768,8 +775,8 @@ static int fill_linkmap(mctp_nl *nl) buf = tmp; } - rc = recvfrom(nl->sd, buf, buflen, 0, - (struct sockaddr *)&addr, &addrlen); + rc = mctp_ops.nl.recvfrom(nl->sd, buf, buflen, 0, + (struct sockaddr *)&addr, &addrlen); if (rc < 0) { warn("recvfrom()"); break; diff --git a/src/mctp-ops.c b/src/mctp-ops.c new file mode 100644 index 0000000..bfb491d --- /dev/null +++ b/src/mctp-ops.c @@ -0,0 +1,72 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * mctp-ops: Abstraction for socket operations for mctp & mctpd. + * + * Copyright (c) 2023 Code Construct + */ + +#define _GNU_SOURCE + +#include +#include + +#include "mctp-ops.h" + +static int mctp_op_mctp_socket(void) +{ + return socket(AF_MCTP, SOCK_DGRAM, 0); +} + +static int mctp_op_netlink_socket(void) +{ + return socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); +} + +static int mctp_op_bind(int sd, struct sockaddr *addr, socklen_t addrlen) +{ + return bind(sd, addr, addrlen); +} + +static int mctp_op_setsockopt(int sd, int level, int optname, void *optval, + socklen_t optlen) +{ + return setsockopt(sd, level, optname, optval, optlen); +} + +static ssize_t mctp_op_sendto(int sd, const void *buf, size_t len, int flags, + const struct sockaddr *dest, socklen_t addrlen) +{ + return sendto(sd, buf, len, flags, dest, addrlen); +} + +static ssize_t mctp_op_recvfrom(int sd, void *buf, size_t len, int flags, + struct sockaddr *src, socklen_t *addrlen) +{ + return recvfrom(sd, buf, len, flags, src, addrlen); +} + +static int mctp_op_close(int sd) +{ + return close(sd); +} + +struct mctp_ops mctp_ops = { + .mctp = { + .socket = mctp_op_mctp_socket, + .setsockopt = mctp_op_setsockopt, + .bind = mctp_op_bind, + .sendto = mctp_op_sendto, + .recvfrom = mctp_op_recvfrom, + .close = mctp_op_close, + }, + .nl = { + .socket = mctp_op_netlink_socket, + .setsockopt = mctp_op_setsockopt, + .bind = mctp_op_bind, + .sendto = mctp_op_sendto, + .recvfrom = mctp_op_recvfrom, + .close = mctp_op_close, + }, +}; + +void mctp_ops_init(void) { } diff --git a/src/mctp-ops.h b/src/mctp-ops.h new file mode 100644 index 0000000..c5d4636 --- /dev/null +++ b/src/mctp-ops.h @@ -0,0 +1,33 @@ + +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * mctpd: bus owner for MCTP using Linux kernel + * + * Copyright (c) 2023 Code Construct + */ +#pragma once + +#include + +#define _GNU_SOURCE + +struct socket_ops { + int (*socket)(void); + int (*setsockopt)(int sd, int level, int optname, void *optval, + socklen_t optlen); + int (*bind)(int sd, struct sockaddr *addr, socklen_t addrlen); + ssize_t (*sendto)(int sd, const void *buf, size_t len, int flags, + const struct sockaddr *dest, socklen_t addrlen); + ssize_t (*recvfrom)(int sd, void *buf, size_t len, int flags, + struct sockaddr *src, socklen_t *addrlen); + int (*close)(int sd); +}; + +struct mctp_ops { + struct socket_ops mctp; + struct socket_ops nl; +}; + +extern struct mctp_ops mctp_ops; + +void mctp_ops_init(void); diff --git a/src/mctp.c b/src/mctp.c index 72ab7db..0dd13b3 100644 --- a/src/mctp.c +++ b/src/mctp.c @@ -33,6 +33,7 @@ #include "mctp.h" #include "mctp-util.h" #include "mctp-netlink.h" +#include "mctp-ops.h" #define max(a, b) ((a) > (b) ? (a) : (b)) #define min(a, b) ((a) < (b) ? (a) : (b)) @@ -769,6 +770,8 @@ static int cmd_addr_addremove(struct ctx *ctx, return -1; } + mctp_ops_init(); + eidstr = argv[1]; linkstr = argv[3]; diff --git a/src/mctpd.c b/src/mctpd.c index 000b4fd..9b30e1c 100644 --- a/src/mctpd.c +++ b/src/mctpd.c @@ -32,6 +32,7 @@ #include "mctp-util.h" #include "mctp-netlink.h" #include "mctp-control-spec.h" +#include "mctp-ops.h" #define max(a, b) ((a) > (b) ? (a) : (b)) #define min(a, b) ((a) < (b) ? (a) : (b)) @@ -417,7 +418,8 @@ static int read_message(ctx *ctx, int sd, uint8_t **ret_buf, size_t *ret_buf_siz uint8_t* buf = NULL; size_t buf_size; - len = recvfrom(sd, NULL, 0, MSG_PEEK | MSG_TRUNC, NULL, 0); + len = mctp_ops.mctp.recvfrom(sd, NULL, 0, MSG_PEEK | MSG_TRUNC, + NULL, 0); if (len < 0) { rc = -errno; goto out; @@ -432,7 +434,8 @@ static int read_message(ctx *ctx, int sd, uint8_t **ret_buf, size_t *ret_buf_siz addrlen = sizeof(struct sockaddr_mctp_ext); memset(ret_addr, 0x0, addrlen); - len = recvfrom(sd, buf, buf_size, MSG_TRUNC, (struct sockaddr *)ret_addr, + len = mctp_ops.mctp.recvfrom(sd, buf, buf_size, MSG_TRUNC, + (struct sockaddr *)ret_addr, &addrlen); if (len < 0) { rc = -errno; @@ -485,8 +488,9 @@ static int reply_message(ctx *ctx, int sd, const void *resp, size_t resp_len, return -EPROTO; } - len = sendto(sd, resp, resp_len, 0, - (struct sockaddr*)&reply_addr, sizeof(reply_addr)); + len = mctp_ops.mctp.sendto(sd, resp, resp_len, 0, + (struct sockaddr *)&reply_addr, + sizeof(reply_addr)); if (len < 0) { return -errno; } @@ -785,7 +789,7 @@ static int listen_control_msg(ctx *ctx, int net) struct sockaddr_mctp addr = { 0 }; int rc, sd = -1, val; - sd = socket(AF_MCTP, SOCK_DGRAM, 0); + sd = mctp_ops.mctp.socket(); if (sd < 0) { rc = -errno; warn("%s: socket() failed", __func__); @@ -798,7 +802,7 @@ static int listen_control_msg(ctx *ctx, int net) addr.smctp_type = MCTP_CTRL_HDR_MSG_TYPE; addr.smctp_tag = MCTP_TAG_OWNER; - rc = bind(sd, (struct sockaddr *)&addr, sizeof(addr)); + rc = mctp_ops.mctp.bind(sd, (struct sockaddr *)&addr, sizeof(addr)); if (rc < 0) { rc = -errno; warn("%s: bind() failed", __func__); @@ -806,7 +810,8 @@ static int listen_control_msg(ctx *ctx, int net) } val = 1; - rc = setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, &val, sizeof(val)); + rc = mctp_ops.mctp.setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, + &val, sizeof(val)); if (rc < 0) { rc = -errno; warn("Kernel does not support MCTP extended addressing"); @@ -936,7 +941,7 @@ static int endpoint_query_addr(ctx *ctx, *resp = NULL; *resp_len = 0; - sd = socket(AF_MCTP, SOCK_DGRAM, 0); + sd = mctp_ops.mctp.socket(); if (sd < 0) { warn("socket"); rc = -errno; @@ -945,7 +950,8 @@ static int endpoint_query_addr(ctx *ctx, // We want extended addressing on all received messages val = 1; - rc = setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, &val, sizeof(val)); + rc = mctp_ops.mctp.setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, + &val, sizeof(val)); if (rc < 0) { rc = -errno; warn("Kernel does not support MCTP extended addressing"); @@ -963,7 +969,8 @@ static int endpoint_query_addr(ctx *ctx, rc = -EPROTO; goto out; } - rc = sendto(sd, req, req_len, 0, (struct sockaddr*)req_addr, req_addr_len); + rc = mctp_ops.mctp.sendto(sd, req, req_len, 0, + (struct sockaddr *)req_addr, req_addr_len); if (rc < 0) { rc = -errno; if (ctx->verbose) { @@ -3477,6 +3484,7 @@ int main(int argc, char **argv) setlinebuf(stdout); setup_config(ctx); + mctp_ops_init(); rc = parse_args(ctx, argc, argv); if (rc != 0) {