diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala index ac6003d0aa..226d070d6e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala @@ -178,13 +178,20 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co when(Normal) { handleExceptions { case Event(Tcp.Received(data), d: NormalData[T @unchecked]) => + log.debug("received chunk of size={}", data.size) val (dec1, plaintextMessages) = d.decryptor.copy(buffer = d.decryptor.buffer ++ data).decrypt() if (plaintextMessages.isEmpty) { connection ! Tcp.ResumeReading stay() using d.copy(decryptor = dec1) } else { - log.debug("read {} messages, waiting for readacks", plaintextMessages.size) + log.debug("decoding {} raw messages", plaintextMessages.size) val unackedReceived = sendToListener(d.listener, plaintextMessages) + if (unackedReceived.isEmpty) { + log.debug("no decoded messages in this chunk, resuming reading") + connection ! Tcp.ResumeReading + } else { + log.debug("decoded {} messages, waiting for readacks", unackedReceived.size) + } stay() using NormalData(d.encryptor, dec1, d.listener, d.sendBuffer, unackedReceived, d.unackedSent) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala index bc7d35797c..0f0b2a89e2 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala @@ -16,8 +16,6 @@ package fr.acinq.eclair.crypto -import java.nio.charset.Charset - import akka.actor.{Actor, ActorLogging, ActorRef, OneForOneStrategy, Props, Stash, SupervisorStrategy, Terminated} import akka.io.Tcp import akka.testkit.{TestActorRef, TestFSMRef, TestProbe} @@ -31,6 +29,7 @@ import scodec.Codec import scodec.bits._ import scodec.codecs._ +import java.nio.charset.Charset import scala.annotation.tailrec import scala.concurrent.duration._ @@ -46,7 +45,7 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be val s = Noise.Secp256k1DHFunctions.generateKeyPair(hex"2121212121212121212121212121212121212121212121212121212121212121") } - test("succesfull handshake") { + test("successful handshake") { val pipe = system.actorOf(Props[MyPipe]()) val probe1 = TestProbe() val probe2 = TestProbe() @@ -76,7 +75,7 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be probe1.expectTerminated(pipe) } - test("succesfull handshake with custom serializer") { + test("successful handshake with custom serializer") { case class MyMessage(payload: String) val mycodec: Codec[MyMessage] = ("payload" | scodec.codecs.string32L(Charset.defaultCharset())).as[MyMessage] val pipe = system.actorOf(Props[MyPipe]()) @@ -108,6 +107,52 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be probe1.expectTerminated(pipe) } + test("handle unknown messages") { + sealed trait Message + case object Msg1 extends Message + case object Msg2 extends Message + + val codec1: Codec[Message] = discriminated[Message].by(uint8) + .typecase(1, provide(Msg1)) + + val codec12: Codec[Message] = discriminated[Message].by(uint8) + .typecase(1, provide(Msg1)) + .typecase(2, provide(Msg2)) + + val pipe = system.actorOf(Props[MyPipePull]()) + val probe1 = TestProbe() + val probe2 = TestProbe() + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, codec1)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, codec12)) + pipe ! (initiator, responder) + + awaitCond(initiator.stateName == TransportHandler.WaitingForListener) + awaitCond(responder.stateName == TransportHandler.WaitingForListener) + + initiator ! Listener(probe1.ref) + responder ! Listener(probe2.ref) + + awaitCond(initiator.stateName == TransportHandler.Normal) + awaitCond(responder.stateName == TransportHandler.Normal) + + responder ! Msg1 + probe1.expectMsg(Msg1) + probe1.reply(TransportHandler.ReadAck(Msg1)) + + responder ! Msg2 + probe1.expectNoMessage(2 seconds) // unknown message + + responder ! Msg1 + probe1.expectMsg(Msg1) + probe1.reply(TransportHandler.ReadAck(Msg1)) + + probe1.watch(pipe) + initiator.stop() + responder.stop() + system.stop(pipe) + probe1.expectTerminated(pipe) + } + test("handle messages split in chunks") { val pipe = system.actorOf(Props[MyPipeSplitter]()) val probe1 = TestProbe() @@ -250,6 +295,41 @@ object TransportHandlerSpec { } } + class MyPipePull extends Actor with Stash { + + def receive = { + case (a: ActorRef, b: ActorRef) => + unstashAll() + context watch a + context watch b + context become ready(a, b, aResume = true, bResume = true) + + case msg => stash() + } + + def ready(a: ActorRef, b: ActorRef, aResume: Boolean, bResume: Boolean): Receive = { + case Tcp.Write(data, ack) if sender().path == a.path => + if (bResume) { + b forward Tcp.Received(data) + if (ack != Tcp.NoAck) sender() ! ack + context become ready(a, b, aResume, bResume = false) + } else stash() + case Tcp.ResumeReading if sender().path == b.path => + unstashAll() + context become ready(a, b, aResume, bResume = true) + case Tcp.Write(data, ack) if sender().path == b.path => + if (aResume) { + a forward Tcp.Received(data) + if (ack != Tcp.NoAck) sender() ! ack + context become ready(a, b, aResume = false, bResume) + } else stash() + case Tcp.ResumeReading if sender().path == a.path => + unstashAll() + context become ready(a, b, aResume = true, bResume) + case Terminated(actor) if actor == a || actor == b => context stop self + } + } + // custom supervisor that will stop an actor if it fails class MySupervisor extends Actor { override def supervisorStrategy: SupervisorStrategy = OneForOneStrategy() {