Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle non-standard onion reply size #5698

Merged
merged 3 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 39 additions & 34 deletions common/sphinx.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

#define BLINDING_FACTOR_SIZE 32

#define ONION_REPLY_SIZE 256

#define RHO_KEYTYPE "rho"

struct hop_params {
Expand Down Expand Up @@ -679,17 +677,36 @@ struct route_step *process_onionpacket(
return step;
}

#if DEVELOPER
unsigned dev_onion_reply_length = 256;
#endif

struct onionreply *create_onionreply(const tal_t *ctx,
const struct secret *shared_secret,
const u8 *failure_msg)
{
size_t msglen = tal_count(failure_msg);
size_t padlen = ONION_REPLY_SIZE - msglen;
size_t padlen;
struct onionreply *reply = tal(ctx, struct onionreply);
u8 *payload = tal_arr(ctx, u8, 0);
struct secret key;
struct hmac hmac;

/* BOLT #4:
* The _erring node_:
* - SHOULD set `pad` such that the `failure_len` plus `pad_len`
* is equal to 256.
* - Note: this value is 118 bytes longer than the longest
* currently-defined message.
*/
const u16 onion_reply_size = IFDEV(dev_onion_reply_length, 256);

/* We never do this currently, but could in future! */
if (msglen > onion_reply_size)
padlen = 0;
else
padlen = onion_reply_size - msglen;

/* BOLT #4:
*
* The node generating the error message (_erring node_) builds a return
Expand All @@ -708,15 +725,8 @@ struct onionreply *create_onionreply(const tal_t *ctx,
towire_u16(&payload, padlen);
towire_pad(&payload, padlen);

/* BOLT #4:
*
* The _erring node_:
* - SHOULD set `pad` such that the `failure_len` plus `pad_len` is
* equal to 256.
* - Note: this value is 118 bytes longer than the longest
* currently-defined message.
*/
assert(tal_count(payload) == ONION_REPLY_SIZE + 4);
/* Two bytes for each length: failure_len and pad_len */
assert(tal_count(payload) == onion_reply_size + 4);

/* BOLT #4:
*
Expand Down Expand Up @@ -763,52 +773,47 @@ u8 *unwrap_onionreply(const tal_t *ctx,
int *origin_index)
{
struct onionreply *r;
struct secret key;
struct hmac hmac;
const u8 *cursor;
u8 *final;
size_t max;
u16 msglen;

if (tal_count(reply->contents) != ONION_REPLY_SIZE + sizeof(hmac) + 4) {
return NULL;
}

r = new_onionreply(tmpctx, reply->contents);
*origin_index = -1;

for (int i = 0; i < numhops; i++) {
struct secret key;
struct hmac hmac, expected_hmac;

/* Since the encryption is just XORing with the cipher
* stream encryption is identical to decryption */
r = wrap_onionreply(tmpctx, &shared_secrets[i], r);

/* Check if the HMAC matches, this means that this is
* the origin */
subkey_from_hmac("um", &shared_secrets[i], &key);
compute_hmac(&key, r->contents + sizeof(hmac.bytes),
tal_count(r->contents) - sizeof(hmac.bytes),
NULL, 0, &hmac);
if (memcmp(hmac.bytes, r->contents, sizeof(hmac.bytes)) == 0) {

cursor = r->contents;
max = tal_count(r->contents);

fromwire_hmac(&cursor, &max, &hmac);
/* Too short. */
if (!cursor)
return NULL;

compute_hmac(&key, cursor, max, NULL, 0, &expected_hmac);
if (hmac_eq(&hmac, &expected_hmac)) {
*origin_index = i;
break;
}
}

/* Didn't find source, it's garbled */
if (*origin_index == -1) {
return NULL;
}

cursor = r->contents + sizeof(hmac);
max = tal_count(r->contents) - sizeof(hmac);
msglen = fromwire_u16(&cursor, &max);

if (msglen > ONION_REPLY_SIZE) {
return NULL;
}

final = tal_arr(ctx, u8, msglen);
if (!fromwire(&cursor, &max, final, msglen))
return tal_free(final);
return final;
return fromwire_tal_arrn(ctx, &cursor, &max, msglen);
}

struct onionpacket *sphinx_decompress(const tal_t *ctx,
Expand Down
3 changes: 3 additions & 0 deletions common/sphinx.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ sphinx_compressed_onion_deserialize(const tal_t *ctx, const u8 *src);
#if DEVELOPER
/* Override to force us to reject valid onion packets */
extern bool dev_fail_process_onionpacket;

/* Override to set custom onion error lengths. */
extern unsigned dev_onion_reply_length;
#endif

#endif /* LIGHTNING_COMMON_SPHINX_H */
5 changes: 5 additions & 0 deletions lightningd/options.c
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,11 @@ static void dev_register_opts(struct lightningd *ld)
opt_register_noarg("--dev-no-ping-timer", opt_set_bool,
&ld->dev_no_ping_timer,
"Don't hang up if we don't get a ping response");
opt_register_arg("--dev-onion-reply-length",
opt_set_uintval,
opt_show_uintval,
&dev_onion_reply_length,
"Send onion errors of custom length");
}
#endif /* DEVELOPER */

Expand Down
5 changes: 3 additions & 2 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ def test_channel_state_change_history(node_factory, bitcoind):
assert(history[3]['message'] == "Closing complete")


@pytest.mark.developer("without DEVELOPER=1, gossip v slow")
@pytest.mark.developer("Gossip slow, and we test --dev-onion-reply-length")
def test_htlc_accepted_hook_fail(node_factory):
"""Send payments from l1 to l2, but l2 just declines everything.

Expand All @@ -1053,7 +1053,8 @@ def test_htlc_accepted_hook_fail(node_factory):
"""
l1, l2, l3 = node_factory.line_graph(3, opts=[
{},
{'plugin': os.path.join(os.getcwd(), 'tests/plugins/fail_htlcs.py')},
{'dev-onion-reply-length': 1111,
'plugin': os.path.join(os.getcwd(), 'tests/plugins/fail_htlcs.py')},
{}
], wait_for_announce=True)

Expand Down
2 changes: 2 additions & 0 deletions wire/fromwire.c
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ u8 *fromwire_tal_arrn(const tal_t *ctx,

arr = tal_arr(ctx, u8, num);
fromwire_u8_array(cursor, max, arr, num);
if (!*cursor)
return tal_free(arr);
return arr;
}

Expand Down