Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SCRAM plus variants #228

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 122 additions & 38 deletions src/auth.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ static int _handle_digestmd5_rspauth(xmpp_conn_t *conn,
static int _handle_scram_challenge(xmpp_conn_t *conn,
xmpp_stanza_t *stanza,
void *userdata);
static char *_make_scram_init_msg(xmpp_conn_t *conn);
struct scram_user_data;
static int _make_scram_init_msg(struct scram_user_data *scram);

static int _handle_missing_features_sasl(xmpp_conn_t *conn, void *userdata);
static int _handle_missing_bind(xmpp_conn_t *conn, void *userdata);
Expand Down Expand Up @@ -243,21 +244,24 @@ _handle_features(xmpp_conn_t *conn, xmpp_stanza_t *stanza, void *userdata)
if (text == NULL)
continue;

if (strcasecmp(text, "PLAIN") == 0)
if (strcasecmp(text, "PLAIN") == 0) {
conn->sasl_support |= SASL_MASK_PLAIN;
else if (strcasecmp(text, "EXTERNAL") == 0 &&
(conn->tls_client_cert || conn->tls_client_key))
} else if (strcasecmp(text, "EXTERNAL") == 0 &&
(conn->tls_client_cert || conn->tls_client_key)) {
conn->sasl_support |= SASL_MASK_EXTERNAL;
else if (strcasecmp(text, "DIGEST-MD5") == 0)
} else if (strcasecmp(text, "DIGEST-MD5") == 0) {
conn->sasl_support |= SASL_MASK_DIGESTMD5;
else if (strcasecmp(text, "SCRAM-SHA-1") == 0)
conn->sasl_support |= SASL_MASK_SCRAMSHA1;
else if (strcasecmp(text, "SCRAM-SHA-256") == 0)
conn->sasl_support |= SASL_MASK_SCRAMSHA256;
else if (strcasecmp(text, "SCRAM-SHA-512") == 0)
conn->sasl_support |= SASL_MASK_SCRAMSHA512;
else if (strcasecmp(text, "ANONYMOUS") == 0)
} else if (strcasecmp(text, "ANONYMOUS") == 0) {
conn->sasl_support |= SASL_MASK_ANONYMOUS;
} else {
size_t n;
for (n = 0; n < scram_algs_num; ++n) {
if (strcasecmp(text, scram_algs[n]->scram_name) == 0) {
conn->sasl_support |= scram_algs[n]->mask;
break;
}
}
}

strophe_free(conn->ctx, text);
}
Expand Down Expand Up @@ -439,7 +443,11 @@ static int _handle_digestmd5_rspauth(xmpp_conn_t *conn,
}

struct scram_user_data {
xmpp_conn_t *conn;
int sasl_plus;
char *scram_init;
char *channel_binding;
const char *first_bare;
const struct hash_alg *alg;
};

Expand Down Expand Up @@ -471,8 +479,9 @@ static int _handle_scram_challenge(xmpp_conn_t *conn,
if (!challenge)
goto err;

response = sasl_scram(conn->ctx, scram_ctx->alg, challenge,
scram_ctx->scram_init, conn->jid, conn->pass);
response =
sasl_scram(conn->ctx, scram_ctx->alg, scram_ctx->channel_binding,
challenge, scram_ctx->first_bare, conn->jid, conn->pass);
strophe_free(conn->ctx, challenge);
if (!response)
goto err;
Expand Down Expand Up @@ -506,7 +515,8 @@ static int _handle_scram_challenge(xmpp_conn_t *conn,
*/
rc = _handle_sasl_result(conn, stanza,
(void *)scram_ctx->alg->scram_name);
strophe_free(conn->ctx, scram_ctx->scram_init);
strophe_free_and_null(conn->ctx, scram_ctx->channel_binding);
strophe_free_and_null(conn->ctx, scram_ctx->scram_init);
strophe_free(conn->ctx, scram_ctx);
}

Expand All @@ -517,33 +527,103 @@ static int _handle_scram_challenge(xmpp_conn_t *conn,
err_free_response:
strophe_free(conn->ctx, response);
err:
strophe_free(conn->ctx, scram_ctx->scram_init);
strophe_free_and_null(conn->ctx, scram_ctx->channel_binding);
strophe_free_and_null(conn->ctx, scram_ctx->scram_init);
strophe_free(conn->ctx, scram_ctx);
disconnect_mem_error(conn);
return 0;
}

static char *_make_scram_init_msg(xmpp_conn_t *conn)
static int _make_scram_init_msg(struct scram_user_data *scram)
{
xmpp_conn_t *conn = scram->conn;
xmpp_ctx_t *ctx = conn->ctx;
size_t message_len;
char *node;
char *message;
char nonce[32];
const void *binding_data;
const char *binding_type;
char *node, *message;
size_t message_len, binding_type_len = 0, binding_data_len;
int l, is_secured = xmpp_conn_is_secured(conn);
/* This buffer must be able to hold:
* "p=<10 bytes binding type>,,<36 bytes binding data>"
* + alignment */
char buf[56];

if (scram->sasl_plus) {
if (!is_secured) {
strophe_error(
ctx, "xmpp",
"SASL: Server requested a -PLUS variant to authenticate, "
"but the connection is not secured. This is an error on "
"the server side we can't do anything about.");
return -1;
}
if (tls_init_channel_binding(conn->tls, &binding_type,
&binding_type_len)) {
return -1;
}
/* directly account for the '=' char in 'p=<binding-type>' */
binding_type_len += 1;
}

node = xmpp_jid_node(ctx, conn->jid);
if (!node) {
return NULL;
return -1;
}
xmpp_rand_nonce(ctx->rand, nonce, sizeof(nonce));
message_len = strlen(node) + strlen(nonce) + 8 + 1;
/* 32 bytes nonce is enough */
xmpp_rand_nonce(ctx->rand, buf, 33);
message_len = strlen(node) + strlen(buf) + 8 + binding_type_len + 1;
message = strophe_alloc(ctx, message_len);
if (message) {
strophe_snprintf(message, message_len, "n,,n=%s,r=%s", node, nonce);
if (!message) {
goto err_node;
}
/* increase length to account for 'y,,', 'n,,' or 'p,,'.
* In the 'p' case the '=' sign has already been accounted for above.
*/
binding_type_len += 3;
if (scram->sasl_plus) {
l = strophe_snprintf(message, message_len, "p=%s,,n=%s,r=%s",
binding_type, node, buf);
} else {
l = strophe_snprintf(message, message_len, "%c,,n=%s,r=%s",
is_secured ? 'y' : 'n', node, buf);
}
if (l < 0 || (size_t)l >= message_len) {
goto err_msg;
}
if (binding_type_len > sizeof(buf)) {
goto err_msg;
}
/* Make `first_bare` point to the 'n' of 'n=<node>' of the
* client-first-message */
scram->first_bare = message + binding_type_len;
memcpy(buf, message, binding_type_len);
if (scram->sasl_plus) {
binding_data =
tls_get_channel_binding_data(conn->tls, &binding_data_len);
if (!binding_data) {
goto err_msg;
}
if (binding_data_len > sizeof(buf) - binding_type_len) {
strophe_error(ctx, "xmpp", "Channel binding data is too long (%zu)",
binding_data_len);
goto err_msg;
}
memcpy(&buf[binding_type_len], binding_data, binding_data_len);
binding_type_len += binding_data_len;
}
scram->channel_binding =
xmpp_base64_encode(ctx, (void *)buf, binding_type_len);
memset(buf, 0, binding_type_len);
strophe_free(ctx, node);
scram->scram_init = message;

return 0;

return message;
err_msg:
strophe_free(ctx, message);
err_node:
strophe_free(ctx, node);
return -1;
}

static xmpp_stanza_t *_make_starttls(xmpp_conn_t *conn)
Expand Down Expand Up @@ -636,7 +716,7 @@ static void _auth(xmpp_conn_t *conn)
return;
}

if (anonjid && conn->sasl_support & SASL_MASK_ANONYMOUS) {
if (anonjid && (conn->sasl_support & SASL_MASK_ANONYMOUS)) {
/* some crap here */
auth = _make_sasl_auth(conn, "ANONYMOUS");
if (!auth) {
Expand Down Expand Up @@ -702,22 +782,26 @@ static void _auth(xmpp_conn_t *conn)
"Password hasn't been set, and SASL ANONYMOUS unsupported.");
xmpp_disconnect(conn);
} else if (conn->sasl_support & SASL_MASK_SCRAM) {
size_t n;
scram_ctx = strophe_alloc(conn->ctx, sizeof(*scram_ctx));
if (conn->sasl_support & SASL_MASK_SCRAMSHA512)
scram_ctx->alg = &scram_sha512;
else if (conn->sasl_support & SASL_MASK_SCRAMSHA256)
scram_ctx->alg = &scram_sha256;
else if (conn->sasl_support & SASL_MASK_SCRAMSHA1)
scram_ctx->alg = &scram_sha1;
memset(scram_ctx, 0, sizeof(*scram_ctx));
for (n = 0; n < scram_algs_num; ++n) {
if (conn->sasl_support & scram_algs[n]->mask) {
scram_ctx->alg = scram_algs[n];
break;
}
}

auth = _make_sasl_auth(conn, scram_ctx->alg->scram_name);
if (!auth) {
disconnect_mem_error(conn);
return;
}

/* don't free scram_init on success */
scram_ctx->scram_init = _make_scram_init_msg(conn);
if (!scram_ctx->scram_init) {
scram_ctx->conn = conn;
scram_ctx->sasl_plus =
scram_ctx->alg->mask & SASL_MASK_SCRAM_PLUS ? 1 : 0;
if (_make_scram_init_msg(scram_ctx)) {
strophe_free(conn->ctx, scram_ctx);
xmpp_stanza_release(auth);
disconnect_mem_error(conn);
Expand Down Expand Up @@ -753,7 +837,7 @@ static void _auth(xmpp_conn_t *conn)

send_stanza(conn, auth, XMPP_QUEUE_STROPHE);

/* SASL SCRAM-SHA-1 was tried, unset flag */
/* SASL algorithm was tried, unset flag */
conn->sasl_support &= ~scram_ctx->alg->mask;
} else if (conn->sasl_support & SASL_MASK_DIGESTMD5) {
auth = _make_sasl_auth(conn, "DIGEST-MD5");
Expand Down
7 changes: 6 additions & 1 deletion src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,14 @@ struct _xmpp_send_queue_t {
#define SASL_MASK_SCRAMSHA256 (1 << 4)
#define SASL_MASK_SCRAMSHA512 (1 << 5)
#define SASL_MASK_EXTERNAL (1 << 6)
#define SASL_MASK_SCRAMSHA1_PLUS (1 << 7)
#define SASL_MASK_SCRAMSHA256_PLUS (1 << 8)

#define SASL_MASK_SCRAM \
#define SASL_MASK_SCRAM_PLUS \
(SASL_MASK_SCRAMSHA1_PLUS | SASL_MASK_SCRAMSHA256_PLUS)
#define SASL_MASK_SCRAM_WEAK \
(SASL_MASK_SCRAMSHA1 | SASL_MASK_SCRAMSHA256 | SASL_MASK_SCRAMSHA512)
#define SASL_MASK_SCRAM (SASL_MASK_SCRAM_PLUS | SASL_MASK_SCRAM_WEAK)

enum {
XMPP_PORT_CLIENT = 5222,
Expand Down
41 changes: 25 additions & 16 deletions src/sasl.c
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ char *sasl_digest_md5(xmpp_ctx_t *ctx,
/** generate auth response string for the SASL SCRAM mechanism */
char *sasl_scram(xmpp_ctx_t *ctx,
const struct hash_alg *alg,
const char *channel_binding,
const char *challenge,
const char *first_bare,
const char *jid,
Expand All @@ -398,6 +399,7 @@ char *sasl_scram(xmpp_ctx_t *ctx,
char *result = NULL;
size_t response_len;
size_t auth_len;
int l;

UNUSED(jid);

Expand Down Expand Up @@ -428,37 +430,44 @@ char *sasl_scram(xmpp_ctx_t *ctx,
}
ival = strtol(i, &saveptr, 10);

auth_len = 10 + strlen(r) + strlen(first_bare) + strlen(challenge);
/* "c=<channel_binding>," + r + ",p=" + sign_b64 + '\0' */
response_len = 3 + strlen(channel_binding) + strlen(r) + 3 +
((alg->digest_size + 2) / 3 * 4) + 1;
response = strophe_alloc(ctx, response_len);
if (!response) {
goto out_sval;
}

auth_len = 3 + response_len + strlen(first_bare) + strlen(challenge);
auth = strophe_alloc(ctx, auth_len);
if (!auth) {
goto out_sval;
goto out_response;
}

/* "c=biws," + r + ",p=" + sign_b64 + '\0' */
response_len = 7 + strlen(r) + 3 + ((alg->digest_size + 2) / 3 * 4) + 1;
response = strophe_alloc(ctx, response_len);
if (!response) {
l = strophe_snprintf(response, response_len, "c=%s,%s", channel_binding, r);
if (l < 0 || (size_t)l >= response_len) {
goto out_auth;
}
l = strophe_snprintf(auth, auth_len, "%s,%s,%s", first_bare, challenge,
response);
if (l < 0 || (size_t)l >= auth_len) {
goto out_auth;
}

strophe_snprintf(response, response_len, "c=biws,%s", r);
strophe_snprintf(auth, auth_len, "%s,%s,%s", first_bare + 3, challenge,
response);

SCRAM_ClientKey(alg, (uint8_t *)password, strlen(password), (uint8_t *)sval,
sval_len, (uint32_t)ival, key);
SCRAM_ClientSignature(alg, key, (uint8_t *)auth, strlen(auth), sign);
SCRAM_ClientProof(alg, sign, key, sign);
SCRAM_ClientProof(alg, key, sign, sign);

sign_b64 = xmpp_base64_encode(ctx, sign, alg->digest_size);
if (!sign_b64) {
goto out_response;
goto out_auth;
}

/* Check for buffer overflow */
if (strlen(response) + strlen(sign_b64) + 3 + 1 > response_len) {
strophe_free(ctx, sign_b64);
goto out_response;
goto out_auth;
}
strcat(response, ",p=");
strcat(response, sign_b64);
Expand All @@ -467,14 +476,14 @@ char *sasl_scram(xmpp_ctx_t *ctx,
response_b64 =
xmpp_base64_encode(ctx, (unsigned char *)response, strlen(response));
if (!response_b64) {
goto out_response;
goto out_auth;
}
result = response_b64;

out_response:
strophe_free(ctx, response);
out_auth:
strophe_free(ctx, auth);
out_response:
strophe_free(ctx, response);
out_sval:
strophe_free(ctx, sval);
out:
Expand Down
1 change: 1 addition & 0 deletions src/sasl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ char *sasl_digest_md5(xmpp_ctx_t *ctx,
const char *password);
char *sasl_scram(xmpp_ctx_t *ctx,
const struct hash_alg *alg,
const char *channel_binding,
const char *challenge,
const char *first_bare,
const char *jid,
Expand Down
Loading
Loading