Skip to content

Commit

Permalink
mctp: introduce ops wrapper for socket operations
Browse files Browse the repository at this point in the history
We'll want to be able to mock the kernel interactions for testing, so
call all of the socket operations (on both netlink and MCTP sockets)
through a central set of funtion pointers.

This will allow us to override with mock implementations for testing.

Signed-off-by: Jeremy Kerr <[email protected]>
  • Loading branch information
jk-ozlabs committed Jan 12, 2024
1 parent 17207f3 commit ee9b1af
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 31 deletions.
5 changes: 3 additions & 2 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ config_h = configure_file(

util_sources = ['src/mctp-util.c']
netlink_sources = ['src/mctp-netlink.c']
ops_sources = ['src/mctp-ops.c']

executable('mctp',
sources: ['src/mctp.c'] + netlink_sources + util_sources,
sources: ['src/mctp.c'] + netlink_sources + util_sources + ops_sources,
install: true,
)

Expand All @@ -53,7 +54,7 @@ if libsystemd.found()
executable('mctpd',
sources: [
'src/mctpd.c',
] + netlink_sources + util_sources,
] + netlink_sources + util_sources + ops_sources,
dependencies: libsystemd,
install: true,
install_dir: get_option('sbindir'),
Expand Down
45 changes: 26 additions & 19 deletions src/mctp-netlink.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mctp-netlink.h"
#include "mctp.h"
#include "mctp-util.h"
#include "mctp-ops.h"

struct linkmap_entry {
int ifindex;
Expand Down Expand Up @@ -65,26 +66,27 @@ static int open_nl_socket(void)
struct sockaddr_nl addr;
int opt, rc, sd = -1;

rc = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
rc = mctp_ops.nl.socket();
if (rc < 0)
goto err;
sd = rc;
memset(&addr, 0, sizeof(addr));
addr.nl_family = AF_NETLINK;
rc = bind(sd, (struct sockaddr *)&addr, sizeof(addr));
rc = mctp_ops.nl.bind(sd, (struct sockaddr *)&addr, sizeof(addr));
if (rc)
goto err;

opt = 1;
rc = setsockopt(sd, SOL_NETLINK, NETLINK_GET_STRICT_CHK,
&opt, sizeof(opt));
rc = mctp_ops.nl.setsockopt(sd, SOL_NETLINK, NETLINK_GET_STRICT_CHK,
&opt, sizeof(opt));
if (rc) {
rc = -errno;
goto err;
}

opt = 1;
rc = setsockopt(sd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, sizeof(opt));
rc = mctp_ops.nl.setsockopt(sd, SOL_NETLINK, NETLINK_EXT_ACK, &opt,
sizeof(opt));
if (rc)
{
rc = -errno;
Expand Down Expand Up @@ -144,8 +146,8 @@ static void free_linkmap(struct linkmap_entry *linkmap, size_t count)
void mctp_nl_close(mctp_nl *nl)
{
free_linkmap(nl->linkmap, nl->linkmap_count);
close(nl->sd);
close(nl->sd_monitor);
mctp_ops.nl.close(nl->sd);
mctp_ops.nl.close(nl->sd_monitor);
free(nl);
}

Expand All @@ -164,14 +166,18 @@ int mctp_nl_monitor(mctp_nl *nl, bool enable)
return nl->sd_monitor;

opt = RTNLGRP_LINK;
rc = setsockopt(nl->sd_monitor, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &opt, sizeof(opt));
rc = mctp_ops.nl.setsockopt(nl->sd_monitor, SOL_NETLINK,
NETLINK_ADD_MEMBERSHIP,
&opt, sizeof(opt));
if (rc < 0) {
rc = -errno;
goto err;
}

opt = RTNLGRP_MCTP_IFADDR;
rc = setsockopt(nl->sd_monitor, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &opt, sizeof(opt));
rc = mctp_ops.nl.setsockopt(nl->sd_monitor, SOL_NETLINK,
NETLINK_ADD_MEMBERSHIP,
&opt, sizeof(opt));
if (rc < 0) {
rc = -errno;
if (errno == EINVAL) {
Expand Down Expand Up @@ -492,7 +498,7 @@ static int handle_nlmsg_ack(mctp_nl *nl)
int rc;
size_t len;

rc = recvfrom(nl->sd, resp, sizeof(resp), 0, NULL, NULL);
rc = mctp_ops.nl.recvfrom(nl->sd, resp, sizeof(resp), 0, NULL, NULL);
if (rc < 0)
return rc;
len = rc;
Expand Down Expand Up @@ -535,8 +541,8 @@ int mctp_nl_send(mctp_nl *nl, struct nlmsghdr *msg)
addr.nl_family = AF_NETLINK;
addr.nl_pid = 0;

rc = sendto(nl->sd, msg, msg->nlmsg_len, 0,
(struct sockaddr *)&addr, sizeof(addr));
rc = mctp_ops.nl.sendto(nl->sd, msg, msg->nlmsg_len, 0,
(struct sockaddr *)&addr, sizeof(addr));
if (rc < 0)
return rc;

Expand Down Expand Up @@ -585,7 +591,8 @@ int mctp_nl_recv_all(mctp_nl *nl, int sd,

// read all the responses into a single buffer
while (!done) {
rc = recvfrom(sd, NULL, 0, MSG_PEEK|MSG_TRUNC, NULL, 0);
rc = mctp_ops.nl.recvfrom(sd, NULL, 0, MSG_PEEK|MSG_TRUNC,
NULL, 0);
if (rc < 0) {
warnx("recvfrom(MSG_PEEK)");
rc = -errno;
Expand Down Expand Up @@ -614,8 +621,8 @@ int mctp_nl_recv_all(mctp_nl *nl, int sd,
resp = respbuf + pos;

addrlen = sizeof(addr);
rc = recvfrom(sd, resp, readlen, MSG_TRUNC,
(struct sockaddr *)&addr, &addrlen);
rc = mctp_ops.nl.recvfrom(sd, resp, readlen, MSG_TRUNC,
(struct sockaddr *)&addr, &addrlen);
if (rc < 0) {
warnx("recvfrom(MSG_PEEK)");
rc = -errno;
Expand Down Expand Up @@ -747,8 +754,8 @@ static int fill_linkmap(mctp_nl *nl)
addrlen = sizeof(addr);

for (;;) {
rc = recvfrom(nl->sd, NULL, 0, MSG_TRUNC | MSG_PEEK,
NULL, NULL);
rc = mctp_ops.nl.recvfrom(nl->sd, NULL, 0, MSG_TRUNC | MSG_PEEK,
NULL, NULL);
if (rc < 0) {
warn("recvfrom(MSG_PEEK)");
break;
Expand All @@ -768,8 +775,8 @@ static int fill_linkmap(mctp_nl *nl)
buf = tmp;
}

rc = recvfrom(nl->sd, buf, buflen, 0,
(struct sockaddr *)&addr, &addrlen);
rc = mctp_ops.nl.recvfrom(nl->sd, buf, buflen, 0,
(struct sockaddr *)&addr, &addrlen);
if (rc < 0) {
warn("recvfrom()");
break;
Expand Down
72 changes: 72 additions & 0 deletions src/mctp-ops.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* SPDX-License-Identifier: GPL-2.0 */
/*
* mctp-ops: Abstraction for socket operations for mctp & mctpd.
*
* Copyright (c) 2023 Code Construct
*/

#define _GNU_SOURCE

#include <unistd.h>
#include <linux/netlink.h>

#include "mctp-ops.h"

static int mctp_op_mctp_socket(void)
{
return socket(AF_MCTP, SOCK_DGRAM, 0);
}

static int mctp_op_netlink_socket(void)
{
return socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
}

static int mctp_op_bind(int sd, struct sockaddr *addr, socklen_t addrlen)
{
return bind(sd, addr, addrlen);
}

static int mctp_op_setsockopt(int sd, int level, int optname, void *optval,
socklen_t optlen)
{
return setsockopt(sd, level, optname, optval, optlen);
}

static ssize_t mctp_op_sendto(int sd, const void *buf, size_t len, int flags,
const struct sockaddr *dest, socklen_t addrlen)
{
return sendto(sd, buf, len, flags, dest, addrlen);
}

static ssize_t mctp_op_recvfrom(int sd, void *buf, size_t len, int flags,
struct sockaddr *src, socklen_t *addrlen)
{
return recvfrom(sd, buf, len, flags, src, addrlen);
}

static int mctp_op_close(int sd)
{
return close(sd);
}

struct mctp_ops mctp_ops = {
.mctp = {
.socket = mctp_op_mctp_socket,
.setsockopt = mctp_op_setsockopt,
.bind = mctp_op_bind,
.sendto = mctp_op_sendto,
.recvfrom = mctp_op_recvfrom,
.close = mctp_op_close,
},
.nl = {
.socket = mctp_op_netlink_socket,
.setsockopt = mctp_op_setsockopt,
.bind = mctp_op_bind,
.sendto = mctp_op_sendto,
.recvfrom = mctp_op_recvfrom,
.close = mctp_op_close,
},
};

void mctp_ops_init(void) { }
33 changes: 33 additions & 0 deletions src/mctp-ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

/* SPDX-License-Identifier: GPL-2.0 */
/*
* mctpd: bus owner for MCTP using Linux kernel
*
* Copyright (c) 2023 Code Construct
*/
#pragma once

#include <sys/socket.h>

#define _GNU_SOURCE

struct socket_ops {
int (*socket)(void);
int (*setsockopt)(int sd, int level, int optname, void *optval,
socklen_t optlen);
int (*bind)(int sd, struct sockaddr *addr, socklen_t addrlen);
ssize_t (*sendto)(int sd, const void *buf, size_t len, int flags,
const struct sockaddr *dest, socklen_t addrlen);
ssize_t (*recvfrom)(int sd, void *buf, size_t len, int flags,
struct sockaddr *src, socklen_t *addrlen);
int (*close)(int sd);
};

struct mctp_ops {
struct socket_ops mctp;
struct socket_ops nl;
};

extern struct mctp_ops mctp_ops;

void mctp_ops_init(void);
3 changes: 3 additions & 0 deletions src/mctp.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mctp.h"
#include "mctp-util.h"
#include "mctp-netlink.h"
#include "mctp-ops.h"

#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
Expand Down Expand Up @@ -769,6 +770,8 @@ static int cmd_addr_addremove(struct ctx *ctx,
return -1;
}

mctp_ops_init();

eidstr = argv[1];
linkstr = argv[3];

Expand Down
28 changes: 18 additions & 10 deletions src/mctpd.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "mctp-util.h"
#include "mctp-netlink.h"
#include "mctp-control-spec.h"
#include "mctp-ops.h"

#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
Expand Down Expand Up @@ -417,7 +418,8 @@ static int read_message(ctx *ctx, int sd, uint8_t **ret_buf, size_t *ret_buf_siz
uint8_t* buf = NULL;
size_t buf_size;

len = recvfrom(sd, NULL, 0, MSG_PEEK | MSG_TRUNC, NULL, 0);
len = mctp_ops.mctp.recvfrom(sd, NULL, 0, MSG_PEEK | MSG_TRUNC,
NULL, 0);
if (len < 0) {
rc = -errno;
goto out;
Expand All @@ -432,7 +434,8 @@ static int read_message(ctx *ctx, int sd, uint8_t **ret_buf, size_t *ret_buf_siz

addrlen = sizeof(struct sockaddr_mctp_ext);
memset(ret_addr, 0x0, addrlen);
len = recvfrom(sd, buf, buf_size, MSG_TRUNC, (struct sockaddr *)ret_addr,
len = mctp_ops.mctp.recvfrom(sd, buf, buf_size, MSG_TRUNC,
(struct sockaddr *)ret_addr,
&addrlen);
if (len < 0) {
rc = -errno;
Expand Down Expand Up @@ -485,8 +488,9 @@ static int reply_message(ctx *ctx, int sd, const void *resp, size_t resp_len,
return -EPROTO;
}

len = sendto(sd, resp, resp_len, 0,
(struct sockaddr*)&reply_addr, sizeof(reply_addr));
len = mctp_ops.mctp.sendto(sd, resp, resp_len, 0,
(struct sockaddr *)&reply_addr,
sizeof(reply_addr));
if (len < 0) {
return -errno;
}
Expand Down Expand Up @@ -785,7 +789,7 @@ static int listen_control_msg(ctx *ctx, int net)
struct sockaddr_mctp addr = { 0 };
int rc, sd = -1, val;

sd = socket(AF_MCTP, SOCK_DGRAM, 0);
sd = mctp_ops.mctp.socket();
if (sd < 0) {
rc = -errno;
warn("%s: socket() failed", __func__);
Expand All @@ -798,15 +802,16 @@ static int listen_control_msg(ctx *ctx, int net)
addr.smctp_type = MCTP_CTRL_HDR_MSG_TYPE;
addr.smctp_tag = MCTP_TAG_OWNER;

rc = bind(sd, (struct sockaddr *)&addr, sizeof(addr));
rc = mctp_ops.mctp.bind(sd, (struct sockaddr *)&addr, sizeof(addr));
if (rc < 0) {
rc = -errno;
warn("%s: bind() failed", __func__);
goto out;
}

val = 1;
rc = setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, &val, sizeof(val));
rc = mctp_ops.mctp.setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT,
&val, sizeof(val));
if (rc < 0) {
rc = -errno;
warn("Kernel does not support MCTP extended addressing");
Expand Down Expand Up @@ -936,7 +941,7 @@ static int endpoint_query_addr(ctx *ctx,
*resp = NULL;
*resp_len = 0;

sd = socket(AF_MCTP, SOCK_DGRAM, 0);
sd = mctp_ops.mctp.socket();
if (sd < 0) {
warn("socket");
rc = -errno;
Expand All @@ -945,7 +950,8 @@ static int endpoint_query_addr(ctx *ctx,

// We want extended addressing on all received messages
val = 1;
rc = setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT, &val, sizeof(val));
rc = mctp_ops.mctp.setsockopt(sd, SOL_MCTP, MCTP_OPT_ADDR_EXT,
&val, sizeof(val));
if (rc < 0) {
rc = -errno;
warn("Kernel does not support MCTP extended addressing");
Expand All @@ -963,7 +969,8 @@ static int endpoint_query_addr(ctx *ctx,
rc = -EPROTO;
goto out;
}
rc = sendto(sd, req, req_len, 0, (struct sockaddr*)req_addr, req_addr_len);
rc = mctp_ops.mctp.sendto(sd, req, req_len, 0,
(struct sockaddr *)req_addr, req_addr_len);
if (rc < 0) {
rc = -errno;
if (ctx->verbose) {
Expand Down Expand Up @@ -3477,6 +3484,7 @@ int main(int argc, char **argv)
setlinebuf(stdout);

setup_config(ctx);
mctp_ops_init();

rc = parse_args(ctx, argc, argv);
if (rc != 0) {
Expand Down

0 comments on commit ee9b1af

Please sign in to comment.