Skip to content

Commit

Permalink
UCP/AM: CR fixes openucx#3
Browse files Browse the repository at this point in the history
  • Loading branch information
amastbaum committed Jun 18, 2024
1 parent 0da08c2 commit f273e12
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 88 deletions.
2 changes: 1 addition & 1 deletion src/ucp/api/ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1881,7 +1881,7 @@ typedef struct ucp_am_handler_param {

/**
* Active Message id.
* @warning Values greater than uint16_t are not supported.
* @warning Value must be between 0 and UINT16_MAX.
*/
unsigned id;

Expand Down
6 changes: 3 additions & 3 deletions src/ucp/core/ucp_am.c
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ ucs_status_t ucp_worker_set_am_recv_handler(ucp_worker_h worker,
}

if (param->id > UINT16_MAX) {
ucs_error("AM handler id %u is outside of allowed range [0, %u]",
param->id, UINT16_MAX);
ucs_error("Invalid AM id %u, must be in range [0, %u]",
param->id, UINT16_MAX);
return UCS_ERR_INVALID_PARAM;
}

Expand Down Expand Up @@ -959,7 +959,7 @@ UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_am_send_nbx,

if (id > UINT16_MAX) {
ret = UCS_STATUS_PTR(UCS_ERR_INVALID_PARAM);
ucs_error("AM handler id %u is outside of allowed range [0, %u]",
ucs_error("Invalid AM id %u, must be in range [0, %u]",
id, UINT16_MAX);
goto out;
}
Expand Down
135 changes: 51 additions & 84 deletions test/gtest/ucp/test_ucp_am.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,9 @@ class test_ucp_am_nbx : public test_ucp_am_base {
}
}

void set_am_data_handler(entity &e, uint16_t am_id,
ucp_am_recv_callback_t cb, void *arg,
unsigned flags = 0)
ucs_status_t set_am_data_handler(entity &e, unsigned am_id,
ucp_am_recv_callback_t cb, void *arg,
unsigned flags = 0)
{
ucp_am_handler_param_t param;

Expand All @@ -437,7 +437,7 @@ class test_ucp_am_nbx : public test_ucp_am_base {
param.flags = flags;
}

ASSERT_UCS_OK(ucp_worker_set_am_recv_handler(e.worker(), &param));
return ucp_worker_set_am_recv_handler(e.worker(), &param);
}

void check_header(const void *header, size_t header_length)
Expand All @@ -448,11 +448,11 @@ class test_ucp_am_nbx : public test_ucp_am_base {

ucs_status_ptr_t
update_counter_and_send_am(const void *header, size_t header_length,
const void *buffer, size_t count,
const void *buffer, size_t count, unsigned am_id,
const ucp_request_param_t *param)
{
m_send_counter++;
return ucp_am_send_nbx(sender().ep(), TEST_AM_NBX_ID, header,
return ucp_am_send_nbx(sender().ep(), am_id, header,
header_length, buffer, count, param);
}

Expand All @@ -478,6 +478,7 @@ class test_ucp_am_nbx : public test_ucp_am_base {
ucs_status_ptr_t sptr = update_counter_and_send_am(hdr, hdr_length,
dt_desc.buf(),
dt_desc.count(),
TEST_AM_NBX_ID,
&param);
return sptr;
}
Expand All @@ -493,8 +494,8 @@ class test_ucp_am_nbx : public test_ucp_am_base {
reset_counters();
ucp_mem_h memh = NULL;

set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_cb, this,
data_cb_flags);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_cb, this, data_cb_flags));

ucp::data_type_desc_t sdt_desc(m_dt, sbuf.ptr(), size);

Expand Down Expand Up @@ -693,94 +694,51 @@ class test_ucp_am_nbx : public test_ucp_am_base {
};

class test_ucp_am_id : public test_ucp_am_nbx {
public:
test_ucp_am_id()
{
reset_counters();
}

protected:

void reset_counters()
{
test_ucp_am_nbx::reset_counters();
m_recv_counter_cb_1 = 0;
m_recv_counter_cb_2 = 0;
}

ucs_status_t set_am_id_overflow_data_handler(entity &e, unsigned am_id,
ucp_am_recv_callback_t cb)
{
ucp_am_handler_param_t param;

/* Initialize Active Message data handler */
param.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID |
UCP_AM_HANDLER_PARAM_FIELD_CB |
UCP_AM_HANDLER_PARAM_FIELD_ARG;
param.id = am_id;
param.cb = cb;
param.arg = this;

return ucp_worker_set_am_recv_handler(e.worker(), &param);
}

void test_am_id_handler(size_t size, size_t header_size = 0ul)
void test_am_id_handler()
{
reset_counters();

EXPECT_EQ(UCS_OK, set_am_id_overflow_data_handler(receiver(),
0x1, am_id_overflow_data_cb_1));
EXPECT_NE(UCS_OK, set_am_id_overflow_data_handler(receiver(),
0xffff0001, am_id_overflow_data_cb_2));
EXPECT_UCS_OK(test_ucp_am_nbx::set_am_data_handler(
receiver(), 0x1,
am_id_overflow_data_cb_1, this));
EXPECT_NE(UCS_OK, test_ucp_am_nbx::set_am_data_handler(
receiver(), 0xffff0001,
am_id_overflow_data_cb_2, this));

ucp_request_param_t param;
param.op_attr_mask = 0;

m_send_counter++;
ucs_status_ptr_t sptr = ucp_am_send_nbx(sender().ep(), 0x1, NULL,
0ul, NULL, 0, &param);
ucs_status_ptr_t sptr = update_counter_and_send_am(
NULL, 0ul, NULL, 0, 0x1, &param);
wait_receives();
EXPECT_EQ(UCS_OK, request_wait(sptr));
ASSERT_UCS_OK(request_wait(sptr));
EXPECT_EQ(m_recv_counter_cb_1, 1);
EXPECT_EQ(m_recv_counter_cb_2, 0);

m_send_counter++;
sptr = ucp_am_send_nbx(sender().ep(), 0xffff0001,
NULL, 0ul, NULL, 0, &param);
sptr = update_counter_and_send_am(NULL, 0ul, NULL, 0, 0xffff0001, &param);
EXPECT_EQ(UCS_PTR_STATUS(sptr), UCS_ERR_INVALID_PARAM);
EXPECT_EQ(m_recv_counter_cb_1, 1);
EXPECT_EQ(m_recv_counter_cb_2, 0);
}

virtual ucs_status_t am_id_overflow_data_handler(
const void *header, size_t header_length,
void *data, size_t length,
const ucp_am_recv_param_t *rx_param,
int handler_num)
{
test_ucp_am_nbx::check_header(header, header_length);
m_recv_counter++;

bool has_reply_ep = get_send_flag();

EXPECT_EQ(has_reply_ep,
!!(rx_param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP));
EXPECT_EQ(has_reply_ep, rx_param->reply_ep != NULL);

if (1 == handler_num) m_recv_counter_cb_1++;
else if (2 == handler_num) m_recv_counter_cb_2++;

return UCS_OK;
}

static ucs_status_t am_id_overflow_data_cb_1(void *arg, const void *header,
size_t header_length,
void *data, size_t length,
const ucp_am_recv_param_t *param)
{
test_ucp_am_id *self = reinterpret_cast<test_ucp_am_id*>(arg);
return self->am_id_overflow_data_handler(header, header_length,
data, length, param, 1);
self->m_recv_counter_cb_1++;
return self->am_data_handler(header, header_length,
data, length, param);

}

static ucs_status_t am_id_overflow_data_cb_2(void *arg, const void *header,
Expand All @@ -789,8 +747,9 @@ class test_ucp_am_id : public test_ucp_am_nbx {
const ucp_am_recv_param_t *param)
{
test_ucp_am_id *self = reinterpret_cast<test_ucp_am_id*>(arg);
return self->am_id_overflow_data_handler(header, header_length,
data, length, param, 2);
self->m_recv_counter_cb_2++;
return self->am_data_handler(header, header_length,
data, length, param);
}

volatile size_t m_recv_counter_cb_1;
Expand Down Expand Up @@ -905,7 +864,7 @@ UCS_TEST_P(test_ucp_am_nbx, am_header_error)
UCS_TEST_P(test_ucp_am_id, am_id_overflow)
{
scoped_log_handler wrap_err(wrap_errors_logger);
test_am_id_handler(0, max_am_hdr());
test_am_id_handler();
}

UCP_INSTANTIATE_TEST_CASE(test_ucp_am_id)
Expand All @@ -920,8 +879,9 @@ UCS_TEST_P(test_ucp_am_nbx, rx_persistent_data)
void *rx_data = NULL;
char data = 'd';

set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_hold_cb, &rx_data,
UCP_AM_FLAG_PERSISTENT_DATA);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_hold_cb, &rx_data,
UCP_AM_FLAG_PERSISTENT_DATA));

ucp_request_param_t param;

Expand Down Expand Up @@ -974,8 +934,9 @@ UCS_TEST_P(test_ucp_am_nbx, rx_am_mpools,
{
void *rx_data = NULL;

set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_hold_cb, &rx_data,
UCP_AM_FLAG_PERSISTENT_DATA);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_hold_cb, &rx_data,
UCP_AM_FLAG_PERSISTENT_DATA));

static const std::string ib_tls[] = { "dc_x", "rc_v", "rc_x", "ud_v",
"ud_x", "ib" };
Expand Down Expand Up @@ -1107,7 +1068,8 @@ class test_ucp_am_nbx_send_copy_header : public test_ucp_am_nbx {
m_hdr = m_hdr_copy;
reset_counters();

set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_cb, this);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_cb, this));
/**
* For RNDV we use 8 byte length to fill the SQ
* so we will not get IN_PROGRESS status from fill_sq.
Expand Down Expand Up @@ -1334,7 +1296,8 @@ class test_ucp_am_nbx_closed_ep : public test_ucp_am_nbx_reply {
std::vector<char> sbuf(size, 'd');
ucp::data_type_desc_t sdt_desc(m_dt, &sbuf[0], size);

set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_rx_check_cb, this);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_rx_check_cb, this));

ucs_status_ptr_t sreq = send_am(sdt_desc, get_send_flag());

Expand Down Expand Up @@ -1753,8 +1716,8 @@ class test_ucp_am_nbx_rndv : public test_ucp_am_nbx_prereg {
sbuf.pattern_fill(SEED);

struct am_cb_args args = { this, &data_desc };
set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_deferred_rndv_cb, &args, 0);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_deferred_rndv_cb, &args, 0));

if (prereg()) {
memh = sender().mem_map(sbuf.ptr(), size);
Expand Down Expand Up @@ -1841,11 +1804,13 @@ UCS_TEST_SKIP_COND_P(test_ucp_am_nbx_rndv, invalid_recv_desc,
ucp_request_param_t param;

struct am_cb_args args = { this, &data_desc };
set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_drop_rndv_cb, &args);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_drop_rndv_cb, &args));

param.op_attr_mask = 0ul;
ucs_status_ptr_t sptr = update_counter_and_send_am(NULL, 0ul, &data,
sizeof(data), &param);
sizeof(data),
TEST_AM_NBX_ID, &param);

wait_receives();

Expand All @@ -1866,8 +1831,8 @@ UCS_TEST_P(test_ucp_am_nbx_rndv, reject_rndv)
{
skip_loopback();

set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_reject_rndv_cb,
this);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_reject_rndv_cb, this));

std::vector<char> sbuf(10000, 0);
ucp_request_param_t param;
Expand All @@ -1882,7 +1847,9 @@ UCS_TEST_P(test_ucp_am_nbx_rndv, reject_rndv)

ucs_status_ptr_t sptr = update_counter_and_send_am(NULL, 0ul,
sbuf.data(),
sbuf.size(), &param);
sbuf.size(),
TEST_AM_NBX_ID,
&param);

EXPECT_EQ(m_status, request_wait(sptr));
EXPECT_EQ(m_recv_counter, m_send_counter);
Expand All @@ -1900,8 +1867,8 @@ UCS_TEST_P(test_ucp_am_nbx_rndv, deferred_reject_rndv)
param.op_attr_mask = 0ul;

struct am_cb_args args = { this, &data_desc };
set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_deferred_rndv_cb,
&args);
EXPECT_UCS_OK(set_am_data_handler(receiver(), TEST_AM_NBX_ID,
am_data_deferred_rndv_cb, &args));

ucs_status_ptr_t sptr = ucp_am_send_nbx(sender().ep(), TEST_AM_NBX_ID,
NULL, 0ul, sbuf.data(),
Expand Down

0 comments on commit f273e12

Please sign in to comment.