diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala index 1a333c8c78..1d063bd0c2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala @@ -94,7 +94,8 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co makeReader(keyPair) } - def sendToListener(listener: ActorRef, plaintextMessages: Seq[ByteVector]): Map[T, Int] = { + def decodeAndSendToListener(listener: ActorRef, plaintextMessages: Seq[ByteVector]): Map[T, Int] = { + log.debug("decoding {} plaintext messages", plaintextMessages.size) var m: Map[T, Int] = Map() plaintextMessages.foreach(plaintext => Try(codec.decode(plaintext.toBitVector)) match { case Success(Attempt.Successful(DecodeResult(message, _))) => @@ -106,6 +107,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co case Failure(t) => log.error(s"cannot deserialize $plaintext: ${t.getMessage}") }) + log.debug("decoded {} messages", m.values.sum) m } @@ -164,29 +166,26 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co case Event(Listener(listener), d@WaitingForListenerData(_, dec)) => context.watch(listener) val (dec1, plaintextMessages) = dec.decrypt() - if (plaintextMessages.isEmpty) { + val unackedReceived1 = decodeAndSendToListener(listener, plaintextMessages) + if (unackedReceived1.isEmpty) { + log.debug("no decoded messages, resuming reading") connection ! Tcp.ResumeReading - goto(Normal) using NormalData(d.encryptor, dec1, listener, sendBuffer = SendBuffer(Queue.empty[T], Queue.empty[T]), unackedReceived = Map.empty[T, Int], unackedSent = None) - } else { - log.debug(s"read ${plaintextMessages.size} messages, waiting for readacks") - val unackedReceived = sendToListener(listener, plaintextMessages) - goto(Normal) using NormalData(d.encryptor, dec1, listener, sendBuffer = SendBuffer(Queue.empty[T], Queue.empty[T]), unackedReceived, unackedSent = None) } + goto(Normal) using NormalData(d.encryptor, dec1, listener, sendBuffer = SendBuffer(Queue.empty[T], Queue.empty[T]), unackedReceived = unackedReceived1, unackedSent = None) } } when(Normal) { handleExceptions { case Event(Tcp.Received(data), d: NormalData[T]) => + log.debug("received chunk of size={}", data.size) val (dec1, plaintextMessages) = d.decryptor.copy(buffer = d.decryptor.buffer ++ data).decrypt() - if (plaintextMessages.isEmpty) { + val unackedReceived1 = decodeAndSendToListener(d.listener, plaintextMessages) + if (unackedReceived1.isEmpty) { + log.debug("no decoded messages, resuming reading") connection ! Tcp.ResumeReading - stay() using d.copy(decryptor = dec1) - } else { - log.debug("read {} messages, waiting for readacks", plaintextMessages.size) - val unackedReceived = sendToListener(d.listener, plaintextMessages) - stay() using NormalData(d.encryptor, dec1, d.listener, d.sendBuffer, unackedReceived, d.unackedSent) } + stay() using d.copy(decryptor = dec1, unackedReceived = unackedReceived1) case Event(ReadAck(msg: T), d: NormalData[T]) => // how many occurences of this message are still unacked? @@ -197,11 +196,10 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co if (unackedReceived1.isEmpty) { log.debug("last incoming message was acked, resuming reading") connection ! Tcp.ResumeReading - stay() using d.copy(unackedReceived = unackedReceived1) } else { log.debug("still waiting for readacks, unacked={}", unackedReceived1) - stay() using d.copy(unackedReceived = unackedReceived1) } + stay() using d.copy(unackedReceived = unackedReceived1) case Event(t: T, d: NormalData[T]) => if (d.sendBuffer.normalPriority.size + d.sendBuffer.lowPriority.size >= MAX_BUFFERED) { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala index 2e915fdd4a..17415031a0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala @@ -16,8 +16,6 @@ package fr.acinq.eclair.crypto -import java.nio.charset.Charset - import akka.actor.{Actor, ActorLogging, ActorRef, OneForOneStrategy, Props, Stash, SupervisorStrategy, Terminated} import akka.io.Tcp import akka.testkit.{TestActorRef, TestFSMRef, TestProbe} @@ -31,6 +29,7 @@ import scodec.Codec import scodec.bits._ import scodec.codecs._ +import java.nio.charset.Charset import scala.annotation.tailrec import scala.concurrent.duration._ @@ -46,7 +45,7 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be val s = Noise.Secp256k1DHFunctions.generateKeyPair(hex"2121212121212121212121212121212121212121212121212121212121212121") } - test("succesfull handshake") { + test("successful handshake") { val pipe = system.actorOf(Props[MyPipe]()) val probe1 = TestProbe() val probe2 = TestProbe() @@ -76,7 +75,7 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be probe1.expectTerminated(pipe) } - test("succesfull handshake with custom serializer") { + test("successful handshake with custom serializer") { case class MyMessage(payload: String) val mycodec: Codec[MyMessage] = ("payload" | scodec.codecs.string32L(Charset.defaultCharset())).as[MyMessage] val pipe = system.actorOf(Props[MyPipe]()) @@ -108,6 +107,52 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be probe1.expectTerminated(pipe) } + test("handle unknown messages") { + sealed trait Message + case object Msg1 extends Message + case object Msg2 extends Message + + val codec1: Codec[Message] = discriminated[Message].by(uint8) + .typecase(1, provide(Msg1)) + + val codec12: Codec[Message] = discriminated[Message].by(uint8) + .typecase(1, provide(Msg1)) + .typecase(2, provide(Msg2)) + + val pipe = system.actorOf(Props[MyPipePull]()) + val probe1 = TestProbe() + val probe2 = TestProbe() + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, codec1)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, codec12)) + pipe ! (initiator, responder) + + awaitCond(initiator.stateName == TransportHandler.WaitingForListener) + awaitCond(responder.stateName == TransportHandler.WaitingForListener) + + initiator ! Listener(probe1.ref) + responder ! Listener(probe2.ref) + + awaitCond(initiator.stateName == TransportHandler.Normal) + awaitCond(responder.stateName == TransportHandler.Normal) + + responder ! Msg1 + probe1.expectMsg(Msg1) + probe1.reply(TransportHandler.ReadAck(Msg1)) + + responder ! Msg2 + probe1.expectNoMessage(2 seconds) // unknown message + + responder ! Msg1 + probe1.expectMsg(Msg1) + probe1.reply(TransportHandler.ReadAck(Msg1)) + + probe1.watch(pipe) + initiator.stop() + responder.stop() + system.stop(pipe) + probe1.expectTerminated(pipe) + } + test("handle messages split in chunks") { val pipe = system.actorOf(Props[MyPipeSplitter]()) val probe1 = TestProbe() @@ -250,6 +295,41 @@ object TransportHandlerSpec { } } + class MyPipePull extends Actor with Stash { + + def receive = { + case (a: ActorRef, b: ActorRef) => + unstashAll() + context watch a + context watch b + context become ready(a, b, aResume = true, bResume = true) + + case msg => stash() + } + + def ready(a: ActorRef, b: ActorRef, aResume: Boolean, bResume: Boolean): Receive = { + case Tcp.Write(data, ack) if sender().path == a.path => + if (bResume) { + b forward Tcp.Received(data) + if (ack != Tcp.NoAck) sender() ! ack + context become ready(a, b, aResume, bResume = false) + } else stash() + case Tcp.ResumeReading if sender().path == b.path => + unstashAll() + context become ready(a, b, aResume, bResume = true) + case Tcp.Write(data, ack) if sender().path == b.path => + if (aResume) { + a forward Tcp.Received(data) + if (ack != Tcp.NoAck) sender() ! ack + context become ready(a, b, aResume = false, bResume) + } else stash() + case Tcp.ResumeReading if sender().path == a.path => + unstashAll() + context become ready(a, b, aResume = true, bResume) + case Terminated(actor) if actor == a || actor == b => context stop self + } + } + // custom supervisor that will stop an actor if it fails class MySupervisor extends Actor { override def supervisorStrategy: SupervisorStrategy = OneForOneStrategy() {