Skip to content

Commit

Permalink
zeromq#4494 added calls to snprintf, but did not take into account th…
Browse files Browse the repository at this point in the history
…at snprintf

can truncate, and then return the number of characters that would have been
written without truncation. Replace such calls with calls to a zmq_snprintf
helper that asserts in the case of error or truncation.

Signed-off-by: Daira Hopwood <[email protected]>
  • Loading branch information
daira committed Feb 1, 2023
1 parent 333c88e commit b1737ec
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 33 deletions.
2 changes: 2 additions & 0 deletions include/zmq.h
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,8 @@ ZMQ_EXPORT void *zmq_threadstart (zmq_thread_fn *func_, void *arg_);
/* Wait for thread to complete then free up resources. */
ZMQ_EXPORT void zmq_threadclose (void *thread_);

/* Equivalent of snprintf that asserts on error or truncation. */
ZMQ_EXPORT int zmq_snprintf ( char *restrict buffer, size_t bufsz, const char *restrict format, ... );

/******************************************************************************/
/* These functions are DRAFT and disabled in stable releases, and subject to */
Expand Down
3 changes: 1 addition & 2 deletions src/tcp_address.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ static std::string make_address_string (const char *hbuf_,
pos += hbuf_len;
memcpy (pos, ipv6_suffix_, sizeof ipv6_suffix_ - 1);
pos += sizeof ipv6_suffix_ - 1;
pos += snprintf (pos, max_port_str_length + 1 * sizeof (char), "%d",
ntohs (port_));
pos += zmq_snprintf (pos, max_port_str_length + 1, "%d", ntohs (port_));
return std::string (buf, pos - buf);
}

Expand Down
5 changes: 2 additions & 3 deletions src/udp_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,8 @@ void zmq::udp_engine_t::sockaddr_to_msg (zmq::msg_t *msg_,
const char *const name = inet_ntoa (addr_->sin_addr);

char port[6];
const int port_len = snprintf (port, 6 * sizeof (char), "%d",
static_cast<int> (ntohs (addr_->sin_port)));
zmq_assert (port_len > 0);
const int port_len = zmq_snprintf (port, 6, "%d",
static_cast<int> (ntohs (addr_->sin_port)));

const size_t name_len = strlen (name);
const int size = static_cast<int> (name_len) + 1 /* colon */
Expand Down
17 changes: 17 additions & 0 deletions src/zmq_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "random.hpp"
#include <assert.h>
#include <new>
#include <cstdio>

#if !defined ZMQ_HAVE_WINDOWS
#include <unistd.h>
Expand Down Expand Up @@ -324,3 +325,19 @@ void zmq_atomic_counter_destroy (void **counter_p_)
delete (static_cast<zmq::atomic_counter_t *> (*counter_p_));
*counter_p_ = NULL;
}

// This function behaves like snprintf in the successful case with no truncation,
// but asserts if either snprintf either returns a negative value (indicating an
// output or encoding error) or truncates. It is similar to sprintf_s but uses
// zmq_assert, and does not require __STDC_WANT_LIB_EXT1__ to be defined.

int zmq_snprintf ( char *restrict buffer, size_t bufsz, const char *restrict format, ... )
{
std::va_list args;
va_start(args, format);
int res = vsnprintf (buffer, bufsz, format, args);
zmq_assert (res >= 0);
zmq_assert (res < bufsz);
va_end(args);
return res;
}
4 changes: 2 additions & 2 deletions tests/test_inproc_connect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ void test_connect_before_bind_ctx_term ()
void *connect_socket = test_context_socket (ZMQ_ROUTER);

char ep[32];
snprintf (ep, 32 * sizeof (char), "inproc://cbbrr%d", i);
zmq_snprintf (ep, 32, "inproc://cbbrr%d", i);
TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (connect_socket, ep));

// Cleanup
Expand Down Expand Up @@ -233,7 +233,7 @@ void test_simultaneous_connect_bind_threads ()
// Set up thread arguments: context followed by endpoint string
for (unsigned int i = 0; i < no_of_times; ++i) {
thr_args[i] = (void *) endpts[i];
snprintf (endpts[i], 20 * sizeof (char), "inproc://foo_%d", i);
zmq_snprintf (endpts[i], 20, "inproc://foo_%d", i);
}

// Spawn all threads as simultaneously as possible
Expand Down
2 changes: 1 addition & 1 deletion tests/test_issue_566.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void test_issue_566 ()
void *dealer = zmq_socket (ctx2, ZMQ_DEALER);
// Leave space for NULL char from sprintf, gcc warning
char routing_id[11];
snprintf (routing_id, 11 * sizeof (char), "%09d", cycle);
zmq_snprintf (routing_id, 11, "%09d", cycle);
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (dealer, ZMQ_ROUTING_ID, routing_id, 10));
int rcvtimeo = 1000;
Expand Down
14 changes: 6 additions & 8 deletions tests/test_proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ static void client_task (void *db_)
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (endpoint, ZMQ_LINGER, &linger, sizeof (linger)));
char endpoint_source[256];
snprintf (endpoint_source, 256 * sizeof (char), "inproc://endpoint%d",
databag->id);
zmq_snprintf (endpoint_source, 256, "inproc://endpoint%d", databag->id);
TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (endpoint, endpoint_source));
char *my_endpoint = s_recv (endpoint);
TEST_ASSERT_NOT_NULL (my_endpoint);
Expand All @@ -108,8 +107,8 @@ static void client_task (void *db_)
char content[CONTENT_SIZE_MAX] = {};
// Set random routing id to make tracing easier
char routing_id[ROUTING_ID_SIZE] = {};
snprintf (routing_id, ROUTING_ID_SIZE * sizeof (char), "%04X-%04X",
rand () % 0xFFFF, rand () % 0xFFFF);
zmq_snprintf (routing_id, ROUTING_ID_SIZE), "%04X-%04X",
rand () % 0xFFFF, rand () % 0xFFFF);
TEST_ASSERT_SUCCESS_ERRNO (zmq_setsockopt (
client, ZMQ_ROUTING_ID, routing_id,
ROUTING_ID_SIZE)); // includes '\0' as an helper for printf
Expand Down Expand Up @@ -166,8 +165,8 @@ static void client_task (void *db_)
}

if (keep_sending) {
snprintf (content, CONTENT_SIZE_MAX * sizeof (char),
"request #%03d", ++request_nbr); // CONTENT_SIZE
zmq_snprintf (content, CONTENT_SIZE_MAX,
"request #%03d", ++request_nbr); // CONTENT_SIZE
if (is_verbose)
printf ("client send - routing_id = %s request #%03d\n",
routing_id, request_nbr);
Expand Down Expand Up @@ -231,8 +230,7 @@ void server_task (void * /*unused_*/)
TEST_ASSERT_NOT_NULL (endpoint_receivers[i]);
TEST_ASSERT_SUCCESS_ERRNO (zmq_setsockopt (
endpoint_receivers[i], ZMQ_LINGER, &linger, sizeof (linger)));
snprintf (endpoint_source, 256 * sizeof (char), "inproc://endpoint%d",
i);
zmq_snprintf (endpoint_source, 256, "inproc://endpoint%d", i);
TEST_ASSERT_SUCCESS_ERRNO (
zmq_bind (endpoint_receivers[i], endpoint_source));
}
Expand Down
8 changes: 4 additions & 4 deletions tests/test_reqrep_tcp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ void make_connect_address (char *connect_address_,
const char *bind_address_)
{
if (ipv6_)
snprintf (connect_address_, 30 * sizeof (char), "tcp://[::1]:%i;%s",
port_, strrchr (bind_address_, '/') + 1);
zmq_snprintf (connect_address_, 30, "tcp://[::1]:%i;%s",
port_, strrchr (bind_address_, '/') + 1);
else
snprintf (connect_address_, 38 * sizeof (char), "tcp://127.0.0.1:%i;%s",
port_, strrchr (bind_address_, '/') + 1);
zmq_snprintf (connect_address_, 38, "tcp://127.0.0.1:%i;%s",
port_, strrchr (bind_address_, '/') + 1);
}

void test_multi_connect (int ipv6_)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_setsockopt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void test_setsockopt_bindtodevice ()
TEST_ASSERT_EQUAL_INT8 ('\0', devname[0]);
TEST_ASSERT_EQUAL_UINT (1, buflen);

snprintf (devname, BOUNDDEVBUFSZ * sizeof (char), "testdev");
zmq_snprintf (devname, BOUNDDEVBUFSZ, "testdev");
buflen = strlen (devname);

TEST_ASSERT_SUCCESS_ERRNO (
Expand Down
4 changes: 2 additions & 2 deletions tests/test_stream_disconnect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ void test_stream_disconnect ()

// Apparently Windows can't connect to 0.0.0.0. A better fix would be welcome.
#ifdef ZMQ_HAVE_WINDOWS
snprintf (connect_endpoint, MAX_SOCKET_STRING * sizeof (char),
"tcp://127.0.0.1:%s", strrchr (bind_endpoint, ':') + 1);
zmq_snprintf (connect_endpoint, MAX_SOCKET_STRING,
"tcp://127.0.0.1:%s", strrchr (bind_endpoint, ':') + 1);
#else
strcpy (connect_endpoint, bind_endpoint);
#endif
Expand Down
12 changes: 6 additions & 6 deletions tests/test_unbind_wildcard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ void test_address_wildcard_ipv4 ()

// Apparently Windows can't connect to 0.0.0.0. A better fix would be welcome.
#ifdef ZMQ_HAVE_WINDOWS
snprintf (connect_endpoint, 256 * sizeof (char), "tcp://127.0.0.1:%s",
strrchr (bind_endpoint, ':') + 1);
zmq_snprintf (connect_endpoint, 256, "tcp://127.0.0.1:%s",
strrchr (bind_endpoint, ':') + 1);
#else
strcpy (connect_endpoint, bind_endpoint);
#endif
Expand Down Expand Up @@ -81,11 +81,11 @@ void test_address_wildcard_ipv6 ()

#ifdef ZMQ_HAVE_WINDOWS
if (ipv6)
snprintf (connect_endpoint, 256 * sizeof (char), "tcp://[::1]:%s",
strrchr (bind_endpoint, ':') + 1);
zmq_snprintf (connect_endpoint, 256, "tcp://[::1]:%s",
strrchr (bind_endpoint, ':') + 1);
else
snprintf (connect_endpoint, 256 * sizeof (char), "tcp://127.0.0.1:%s",
strrchr (bind_endpoint, ':') + 1);
zmq_snprintf (connect_endpoint, 256, "tcp://127.0.0.1:%s",
strrchr (bind_endpoint, ':') + 1);
#else
strcpy (connect_endpoint, bind_endpoint);
#endif
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ws_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ void test_roundtrip ()
zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, bind_address, &addr_length));

// Windows can't connect to 0.0.0.0
snprintf (connect_address, MAX_SOCKET_STRING * sizeof (char),
"ws://127.0.0.1%s", strrchr (bind_address, ':'));
zmq_snprintf (connect_address, MAX_SOCKET_STRING,
"ws://127.0.0.1%s", strrchr (bind_address, ':'));

TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address));

Expand Down
4 changes: 2 additions & 2 deletions tests/testutil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ fd_t bind_socket_resolve_port (const char *address_,
addr_len = sizeof (struct sockaddr_storage);
TEST_ASSERT_SUCCESS_RAW_ERRNO (
getsockname (s_pre, (struct sockaddr *) &addr, &addr_len));
snprintf (
my_endpoint_, 6 + strlen (address_) + 7 * sizeof (char), "%s://%s:%u",
zmq_snprintf (
my_endpoint_, 6 + strlen (address_) + 7, "%s://%s:%u",
protocol_ == IPPROTO_TCP ? "tcp"
: protocol_ == IPPROTO_UDP ? "udp"
: protocol_ == IPPROTO_WSS ? "wss"
Expand Down

0 comments on commit b1737ec

Please sign in to comment.