diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index 7142924b7a..ddeaec963b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -318,12 +318,16 @@ object Sphinx extends Logging { * @return an encrypted failure packet that can be sent to the destination node. */ def wrap(packet: ByteVector, sharedSecret: ByteVector32): ByteVector = { - require(packet.length == PacketLength, s"invalid error packet length ${packet.length}, must be $PacketLength") + if (packet.length != PacketLength) { + logger.warn(s"invalid error packet length ${packet.length}, must be $PacketLength (malicious or buggy downstream node)") + } val key = generateKey("ammag", sharedSecret) val stream = generateStream(key, PacketLength) logger.debug(s"ammag key: $key") logger.debug(s"error stream: $stream") - packet xor stream + // If we received a packet with an invalid length, we trim and pad to forward a packet with a normal length upstream. + // This is a poor man's attempt at increasing the likelihood of the sender receiving the error. + packet.take(PacketLength).padLeft(PacketLength) xor stream } /** diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index 10182a87db..b235d19c2a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -802,7 +802,6 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { bob2blockchain.expectMsgType[WatchConfirmed] } - test("recv RevokeAndAck (one htlc sent)") { f => import f._ val sender = TestProbe() @@ -1329,6 +1328,20 @@ class NormalStateSpec extends TestkitBaseClass with StateTestsHelperMethods { alice2blockchain.expectMsgType[WatchConfirmed] } + test("recv UpdateFailHtlc (invalid onion error length)") { f => + import f._ + val sender = TestProbe() + val (_, htlc) = addHtlc(50000000 msat, alice, bob, alice2bob, bob2alice) + crossSign(alice, bob, alice2bob, bob2alice) + // Bob receives a failure with a completely invalid onion error (missing mac) + sender.send(bob, CMD_FAIL_HTLC(htlc.id, Left(ByteVector.fill(260)(42)))) + sender.expectMsg("ok") + val fail = bob2alice.expectMsgType[UpdateFailHtlc] + assert(fail.id === htlc.id) + // We should rectify the packet length before forwarding upstream. + assert(fail.reason.length === Sphinx.FailurePacket.PacketLength) + } + test("recv CMD_UPDATE_FEE") { f => import f._ val sender = TestProbe() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 1dbdecffbe..9f08344ecf 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -18,16 +18,16 @@ package fr.acinq.eclair.crypto import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} -import fr.acinq.eclair.{UInt64, wire} import fr.acinq.eclair.wire._ +import fr.acinq.eclair.{UInt64, wire} import org.scalatest.FunSuite import scodec.bits._ import scala.util.Success /** - * Created by fabrice on 10/01/17. - */ + * Created by fabrice on 10/01/17. + */ class SphinxSpec extends FunSuite { import Sphinx._ @@ -299,6 +299,20 @@ class SphinxSpec extends FunSuite { } } + test("intermediate node replies with an invalid onion payload length") { + // The error will not be recoverable by the sender, but we must still forward it. + val sharedSecret = ByteVector32(hex"4242424242424242424242424242424242424242424242424242424242424242") + val errors = Seq( + ByteVector.fill(FailurePacket.PacketLength - MacLength)(13), + ByteVector.fill(FailurePacket.PacketLength + MacLength)(13) + ) + + for (error <- errors) { + val wrapped = FailurePacket.wrap(error, sharedSecret) + assert(wrapped.length === FailurePacket.PacketLength) + } + } + test("intermediate node replies with a failure message (reference test vector)") { for (payloads <- Seq(referenceFixedSizePayloads, referenceVariableSizePayloads, variableSizePayloadsFull)) { // route: origin -> node #0 -> node #1 -> node #2 -> node #3 -> node #4