From 17c9eb9d95f47db6be0b544e495cd2248a368d0f Mon Sep 17 00:00:00 2001 From: t-bast Date: Wed, 9 Apr 2025 11:28:04 +0200 Subject: [PATCH 1/3] Clean-up `TransportHandler` The `TransportHandler` is a very old actor that we never refactored much and needed a bit of clean-up. There is no reason to make it generic, it only supports sending lightning messages on the wire. If we ever need to make it generic in the future, we can easily do it, but for simplicity it should only handle lightning messages for now. --- .../eclair/crypto/TransportHandler.scala | 106 +++++++-------- .../eclair/crypto/TransportHandlerSpec.scala | 124 ++++++------------ 2 files changed, 87 insertions(+), 143 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala index 32c8e965dd..c05106d51e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.crypto.ChaCha20Poly1305.ChaCha20Poly1305Error import fr.acinq.eclair.crypto.Noise._ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes -import fr.acinq.eclair.wire.protocol.{AnnouncementSignatures, RoutingMessage} +import fr.acinq.eclair.wire.protocol.{AnnouncementSignatures, LightningMessage, RoutingMessage} import fr.acinq.eclair.{Diagnostics, FSMDiagnosticActorLogging, Logs, getSimpleClassName} import scodec.bits.ByteVector import scodec.{Attempt, Codec, DecodeResult} @@ -40,7 +40,7 @@ import scala.reflect.ClassTag /** * see BOLT #8 * This class handles the transport layer: - * - initial handshake. upon completion we will have a pair of cipher states (one for encryption, one for decryption) + * - initial handshake. 
upon completion we will have a pair of cipher states (one for encryption, one for decryption) * - encryption/decryption of messages * * Once the initial handshake has been completed successfully, the handler will create a listener actor with the @@ -50,23 +50,22 @@ import scala.reflect.ClassTag * @param rs remote node static public key (which must be known before we initiate communication) * @param connection actor that represents the other node's */ -class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], connection: ActorRef, codec: Codec[T]) extends Actor with FSMDiagnosticActorLogging[TransportHandler.State, TransportHandler.Data] { +class TransportHandler(keyPair: KeyPair, rs: Option[ByteVector], connection: ActorRef, codec: Codec[LightningMessage]) extends Actor with FSMDiagnosticActorLogging[TransportHandler.State, TransportHandler.Data] { // will hold the peer's public key once it is available (we don't know it right away in case of an incoming connection) var remoteNodeId_opt: Option[PublicKey] = rs.map(PublicKey(_)) - val wireLog = new BusLogging(context.system.eventStream, "", classOf[Diagnostics], context.system.asInstanceOf[ExtendedActorSystem].logFilter) with DiagnosticLoggingAdapter + private val wireLog = new BusLogging(context.system.eventStream, "", classOf[Diagnostics], context.system.asInstanceOf[ExtendedActorSystem].logFilter) with DiagnosticLoggingAdapter - def diag(message: T, direction: String): Unit = { - require(direction == "IN" || direction == "OUT") + private def logMessage(message: LightningMessage, direction: String): Unit = { val channelId_opt = Logs.channelId(message) wireLog.mdc(Logs.mdc(LogCategory(message), remoteNodeId_opt, channelId_opt)) if (channelId_opt.isDefined) { // channel-related messages are logged as info - wireLog.info(s"$direction msg={}", message) + wireLog.info("{} msg={}", direction, message) } else { // other messages (e.g. routing gossip) are logged as debug - wireLog.debug(s"$direction msg={}", message) + wireLog.debug("{} msg={}", direction, message) } wireLog.clearMDC() } @@ -79,11 +78,11 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co def buf(message: ByteVector): ByteString = ByteString.fromArray(message.toArray) // it means we initiate the dialog - val isWriter = rs.isDefined + private val isWriter = rs.isDefined context.watch(connection) - val reader = if (isWriter) { + private val reader = if (isWriter) { val state = makeWriter(keyPair, rs.get) val (state1, message, None) = state.write(ByteVector.empty) log.debug(s"sending prefix + $message") @@ -93,12 +92,12 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co makeReader(keyPair) } - def decodeAndSendToListener(listener: ActorRef, plaintextMessages: Seq[ByteVector]): Map[T, Int] = { + private def decodeAndSendToListener(listener: ActorRef, plaintextMessages: Seq[ByteVector]): Map[LightningMessage, Int] = { log.debug("decoding {} plaintext messages", plaintextMessages.size) - var m: Map[T, Int] = Map() + var m = Map.empty[LightningMessage, Int] plaintextMessages.foreach(plaintext => codec.decode(plaintext.bits) match { case Attempt.Successful(DecodeResult(message, _)) => - diag(message, "IN") + logMessage(message, "IN") Monitoring.Metrics.MessageSize.withTag(Monitoring.Tags.MessageDirection, Monitoring.Tags.MessageDirections.IN).record(plaintext.size) listener ! 
message m += (message -> (m.getOrElse(message, 0) + 1)) @@ -132,25 +131,22 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co val nextStateData = WaitingForListenerData(Encryptor(ExtendedCipherState(enc, ck)), Decryptor(ExtendedCipherState(dec, ck), ciphertextLength = None, remainder)) goto(WaitingForListener) using nextStateData - case (writer, _, None) => { + case (writer, _, None) => writer.write(ByteVector.empty) match { - case (reader1, message, None) => { + case (reader1, message, None) => // we're still in the middle of the handshake process and the other end must first received our next // message before they can reply if (remainder.nonEmpty) throw UnexpectedDataDuringHandshake(ByteVector(remainder)) connection ! Tcp.Write(buf(TransportHandler.prefix +: message)) stay() using HandshakeData(reader1, remainder) - } - case (_, message, Some((enc, dec, ck))) => { + case (_, message, Some((enc, dec, ck))) => connection ! Tcp.Write(buf(TransportHandler.prefix +: message)) val remoteNodeId = PublicKey(writer.rs) remoteNodeId_opt = Some(remoteNodeId) context.parent ! HandshakeCompleted(remoteNodeId) val nextStateData = WaitingForListenerData(Encryptor(ExtendedCipherState(enc, ck)), Decryptor(ExtendedCipherState(dec, ck), ciphertextLength = None, remainder)) goto(WaitingForListener) using nextStateData - } } - } } } } @@ -169,13 +165,13 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co log.debug("no decoded messages, resuming reading") connection ! Tcp.ResumeReading } - goto(Normal) using NormalData(d.encryptor, dec1, listener, sendBuffer = SendBuffer(Queue.empty[T], Queue.empty[T]), unackedReceived = unackedReceived1, unackedSent = None) + goto(Normal) using NormalData(d.encryptor, dec1, listener, sendBuffer = SendBuffer(Queue.empty[LightningMessage], Queue.empty[LightningMessage]), unackedReceived = unackedReceived1, unackedSent = None) } } when(Normal) { handleExceptions { - case Event(Tcp.Received(data), d: NormalData[T @unchecked]) => + case Event(Tcp.Received(data), d: NormalData) => log.debug("received chunk of size={}", data.size) val (dec1, plaintextMessages) = d.decryptor.copy(buffer = d.decryptor.buffer ++ data).decrypt() val unackedReceived1 = decodeAndSendToListener(d.listener, plaintextMessages) @@ -185,7 +181,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co } stay() using d.copy(decryptor = dec1, unackedReceived = unackedReceived1) - case Event(ReadAck(msg: T), d: NormalData[T @unchecked]) => + case Event(ReadAck(msg: LightningMessage), d: NormalData) => // how many occurrences of this message are still unacked? val remaining = d.unackedReceived.getOrElse(msg, 0) - 1 log.debug("acking message {}", msg) @@ -199,7 +195,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co } stay() using d.copy(unackedReceived = unackedReceived1) - case Event(t: T, d: NormalData[T @unchecked]) => + case Event(t: LightningMessage, d: NormalData) => if (d.sendBuffer.normalPriority.size + d.sendBuffer.lowPriority.size >= MAX_BUFFERED) { log.warning("send buffer overrun, closing connection") connection ! 
PoisonPill @@ -213,7 +209,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co } stay() using d.copy(sendBuffer = sendBuffer1) } else { - diag(t, "OUT") + logMessage(t, "OUT") val blob = codec.encode(t).require.toByteVector Monitoring.Metrics.MessageSize.withTag(Monitoring.Tags.MessageDirection, Monitoring.Tags.MessageDirections.OUT).record(blob.size) val (enc1, ciphertext) = d.encryptor.encrypt(blob) @@ -221,9 +217,9 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co stay() using d.copy(encryptor = enc1, unackedSent = Some(t)) } - case Event(WriteAck, d: NormalData[T @unchecked]) => - def send(t: T) = { - diag(t, "OUT") + case Event(WriteAck, d: NormalData) => + def send(t: LightningMessage) = { + logMessage(t, "OUT") val blob = codec.encode(t).require.toByteVector Monitoring.Metrics.MessageSize.withTag(Monitoring.Tags.MessageDirection, Monitoring.Tags.MessageDirections.OUT).record(blob.size) val (enc1, ciphertext) = d.encryptor.encrypt(blob) @@ -260,7 +256,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co case Event(msg, d) => d match { - case n: NormalData[_] => log.warning(s"unhandled message $msg in state normal unackedSent=${n.unackedSent.size} unackedReceived=${n.unackedReceived.size} sendBuffer.lowPriority=${n.sendBuffer.lowPriority.size} sendBuffer.normalPriority=${n.sendBuffer.normalPriority.size}") + case n: NormalData => log.warning(s"unhandled message $msg in state normal unackedSent=${n.unackedSent.size} unackedReceived=${n.unackedReceived.size} sendBuffer.lowPriority=${n.sendBuffer.lowPriority.size} sendBuffer.normalPriority=${n.sendBuffer.normalPriority.size}") case _ => log.warning(s"unhandled message $msg in state ${d.getClass.getSimpleName}") } stay() @@ -273,7 +269,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co Logs.withMdc(diagLog)(Logs.mdc(category_opt = Some(Logs.LogCategory.CONNECTION), remoteNodeId_opt = remoteNodeId_opt)) { connection ! 
Tcp.Close // attempts to gracefully close the connection when dying stateData match { - case normal: NormalData[_] => + case normal: NormalData => // NB: we deduplicate on the class name: each class will appear once but there may be many instances (less verbose and gives debug hints) log.info("stopping (unackedReceived={} unackedSent={})", normal.unackedReceived.keys.map(getSimpleClassName).toSet.mkString(","), normal.unackedSent.map(getSimpleClassName)) case _ => @@ -297,8 +293,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co case t: Throwable => t match { // for well known crypto error, we don't display the stack trace - case _: InvalidTransportPrefix => log.error(s"crypto error: ${t.getMessage}") - case _: ChaCha20Poly1305Error => log.error(s"crypto error: ${t.getMessage}") + case _: InvalidTransportPrefix | _: ChaCha20Poly1305Error => log.error(s"crypto error: ${t.getMessage}") case _ => log.error(t, "") } throw t @@ -309,19 +304,19 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co object TransportHandler { - def props[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], connection: ActorRef, codec: Codec[T]): Props = Props(new TransportHandler(keyPair, rs, connection, codec)) + def props[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], connection: ActorRef, codec: Codec[LightningMessage]): Props = Props(new TransportHandler(keyPair, rs, connection, codec)) - val MAX_BUFFERED = 1000000L + private val MAX_BUFFERED = 1000000L // see BOLT #8 // this prefix is prepended to all Noise messages sent during the handshake phase val prefix: Byte = 0x00 + private val prologue = ByteVector.view("lightning".getBytes("UTF-8")) - case class InvalidTransportPrefix(buffer: ByteVector) extends RuntimeException(s"invalid transport prefix first64=${buffer.take(64).toHex}") - - case class UnexpectedDataDuringHandshake(buffer: ByteVector) extends RuntimeException(s"unexpected additional data received during handshake first64=${buffer.take(64).toHex}") - - val prologue = ByteVector.view("lightning".getBytes("UTF-8")) + // @formatter:off + private case class InvalidTransportPrefix(buffer: ByteVector) extends RuntimeException(s"invalid transport prefix first64=${buffer.take(64).toHex}") + private case class UnexpectedDataDuringHandshake(buffer: ByteVector) extends RuntimeException(s"unexpected additional data received during handshake first64=${buffer.take(64).toHex}") + // @formatter:on /** * See BOLT #8: during the handshake phase we are expecting 3 messages of 50, 50 and 66 bytes (including the prefix) @@ -329,17 +324,17 @@ object TransportHandler { * @param reader handshake state reader * @return the size of the message the reader is expecting */ - def expectedLength(reader: Noise.HandshakeStateReader) = reader.messages.length match { + private def expectedLength(reader: Noise.HandshakeStateReader): Int = reader.messages.length match { case 3 | 2 => 50 case 1 => 66 } - def makeWriter(localStatic: KeyPair, remoteStatic: ByteVector) = Noise.HandshakeState.initializeWriter( + private def makeWriter(localStatic: KeyPair, remoteStatic: ByteVector): HandshakeStateWriter = Noise.HandshakeState.initializeWriter( Noise.handshakePatternXK, prologue, localStatic, KeyPair(ByteVector.empty, ByteVector.empty), remoteStatic, ByteVector.empty, Noise.Secp256k1DHFunctions, Noise.Chacha20Poly1305CipherFunctions, Noise.SHA256HashFunctions) - def makeReader(localStatic: KeyPair) = Noise.HandshakeState.initializeReader( + private def 
makeReader(localStatic: KeyPair): HandshakeStateReader = Noise.HandshakeState.initializeReader( Noise.handshakePatternXK, prologue, localStatic, KeyPair(ByteVector.empty, ByteVector.empty), ByteVector.empty, ByteVector.empty, Noise.Secp256k1DHFunctions, Noise.Chacha20Poly1305CipherFunctions, Noise.SHA256HashFunctions) @@ -351,37 +346,32 @@ object TransportHandler { * @param ck chaining key */ case class ExtendedCipherState(cs: CipherState, ck: ByteVector) extends CipherState { - override def cipher: CipherFunctions = cs.cipher - - override def hasKey: Boolean = cs.hasKey + override val cipher: CipherFunctions = cs.cipher + override val hasKey: Boolean = cs.hasKey override def encryptWithAd(ad: ByteVector, plaintext: ByteVector): (CipherState, ByteVector) = { cs match { case UninitializedCipherState(_) => (this, plaintext) - case InitializedCipherState(k, n, _) if n == 999 => { + case InitializedCipherState(k, n, _) if n == 999 => val (_, ciphertext) = cs.encryptWithAd(ad, plaintext) val (ck1, k1) = SHA256HashFunctions.hkdf(ck, k) (this.copy(cs = cs.initializeKey(k1), ck = ck1), ciphertext) - } - case InitializedCipherState(_, n, _) => { + case _: InitializedCipherState => val (cs1, ciphertext) = cs.encryptWithAd(ad, plaintext) (this.copy(cs = cs1), ciphertext) - } } } override def decryptWithAd(ad: ByteVector, ciphertext: ByteVector): (CipherState, ByteVector) = { cs match { case UninitializedCipherState(_) => (this, ciphertext) - case InitializedCipherState(k, n, _) if n == 999 => { + case InitializedCipherState(k, n, _) if n == 999 => val (_, plaintext) = cs.decryptWithAd(ad, ciphertext) val (ck1, k1) = SHA256HashFunctions.hkdf(ck, k) (this.copy(cs = cs.initializeKey(k1), ck = ck1), plaintext) - } - case InitializedCipherState(_, n, _) => { + case _: InitializedCipherState => val (cs1, plaintext) = cs.decryptWithAd(ad, ciphertext) (this.copy(cs = cs1), plaintext) - } } } } @@ -438,23 +428,23 @@ object TransportHandler { // @formatter:off sealed trait State - case object Handshake extends State + private case object Handshake extends State case object WaitingForListener extends State case object Normal extends State sealed trait Data - case class HandshakeData(reader: Noise.HandshakeStateReader, buffer: ByteString = ByteString.empty) extends Data - case class WaitingForListenerData(encryptor: Encryptor, decryptor: Decryptor) extends Data - case class NormalData[T](encryptor: Encryptor, decryptor: Decryptor, listener: ActorRef, sendBuffer: SendBuffer[T], unackedReceived: Map[T, Int], unackedSent: Option[T]) extends Data + private case class HandshakeData(reader: Noise.HandshakeStateReader, buffer: ByteString = ByteString.empty) extends Data + private case class WaitingForListenerData(encryptor: Encryptor, decryptor: Decryptor) extends Data + private case class NormalData(encryptor: Encryptor, decryptor: Decryptor, listener: ActorRef, sendBuffer: SendBuffer, unackedReceived: Map[LightningMessage, Int], unackedSent: Option[LightningMessage]) extends Data - case class SendBuffer[T](normalPriority: Queue[T], lowPriority: Queue[T]) + private case class SendBuffer(normalPriority: Queue[LightningMessage], lowPriority: Queue[LightningMessage]) case class Listener(listener: ActorRef) case class HandshakeCompleted(remoteNodeId: PublicKey) case class ReadAck(msg: Any) extends RemoteTypes - case object WriteAck extends Tcp.Event + private case object WriteAck extends Tcp.Event // @formatter:on } \ No newline at end of file diff --git 
a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala index 0f0b2a89e2..725233c9d6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala @@ -22,14 +22,14 @@ import akka.testkit.{TestActorRef, TestFSMRef, TestProbe} import fr.acinq.eclair.TestKitBaseClass import fr.acinq.eclair.crypto.Noise.{Chacha20Poly1305CipherFunctions, CipherState} import fr.acinq.eclair.crypto.TransportHandler.{Encryptor, ExtendedCipherState, Listener} -import fr.acinq.eclair.wire.protocol.CommonCodecs +import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{lightningMessageCodec, pingCodec} +import fr.acinq.eclair.wire.protocol.{LightningMessage, Ping, Pong} import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuiteLike import scodec.Codec import scodec.bits._ import scodec.codecs._ -import java.nio.charset.Charset import scala.annotation.tailrec import scala.concurrent.duration._ @@ -38,19 +38,19 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be import TransportHandlerSpec._ object Initiator { - val s = Noise.Secp256k1DHFunctions.generateKeyPair(hex"1111111111111111111111111111111111111111111111111111111111111111") + val s: Noise.KeyPair = Noise.Secp256k1DHFunctions.generateKeyPair(hex"1111111111111111111111111111111111111111111111111111111111111111") } object Responder { - val s = Noise.Secp256k1DHFunctions.generateKeyPair(hex"2121212121212121212121212121212121212121212121212121212121212121") + val s: Noise.KeyPair = Noise.Secp256k1DHFunctions.generateKeyPair(hex"2121212121212121212121212121212121212121212121212121212121212121") } test("successful handshake") { val pipe = system.actorOf(Props[MyPipe]()) val probe1 = TestProbe() val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, CommonCodecs.varsizebinarydata)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata)) + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, lightningMessageCodec)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, lightningMessageCodec)) pipe ! 
(initiator, responder) awaitCond(initiator.stateName == TransportHandler.WaitingForListener) @@ -62,43 +62,11 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be awaitCond(initiator.stateName == TransportHandler.Normal) awaitCond(responder.stateName == TransportHandler.Normal) - initiator.tell(ByteVector("hello".getBytes), probe1.ref) - probe2.expectMsg(ByteVector("hello".getBytes)) + initiator.tell(Ping(1105, ByteVector("hello".getBytes)), probe1.ref) + probe2.expectMsg(Ping(1105, ByteVector("hello".getBytes))) - responder.tell(ByteVector("bonjour".getBytes), probe2.ref) - probe1.expectMsg(ByteVector("bonjour".getBytes)) - - probe1.watch(pipe) - initiator.stop() - responder.stop() - system.stop(pipe) - probe1.expectTerminated(pipe) - } - - test("successful handshake with custom serializer") { - case class MyMessage(payload: String) - val mycodec: Codec[MyMessage] = ("payload" | scodec.codecs.string32L(Charset.defaultCharset())).as[MyMessage] - val pipe = system.actorOf(Props[MyPipe]()) - val probe1 = TestProbe() - val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, mycodec)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, mycodec)) - pipe ! (initiator, responder) - - awaitCond(initiator.stateName == TransportHandler.WaitingForListener) - awaitCond(responder.stateName == TransportHandler.WaitingForListener) - - initiator ! Listener(probe1.ref) - responder ! Listener(probe2.ref) - - awaitCond(initiator.stateName == TransportHandler.Normal) - awaitCond(responder.stateName == TransportHandler.Normal) - - initiator.tell(MyMessage("hello"), probe1.ref) - probe2.expectMsg(MyMessage("hello")) - - responder.tell(MyMessage("bonjour"), probe2.ref) - probe1.expectMsg(MyMessage("bonjour")) + responder.tell(Pong(ByteVector("bonjour".getBytes)), probe2.ref) + probe1.expectMsg(Pong(ByteVector("bonjour".getBytes))) probe1.watch(pipe) initiator.stop() @@ -108,22 +76,14 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be } test("handle unknown messages") { - sealed trait Message - case object Msg1 extends Message - case object Msg2 extends Message - - val codec1: Codec[Message] = discriminated[Message].by(uint8) - .typecase(1, provide(Msg1)) - - val codec12: Codec[Message] = discriminated[Message].by(uint8) - .typecase(1, provide(Msg1)) - .typecase(2, provide(Msg2)) + val incompleteCodec: Codec[LightningMessage] = discriminated[LightningMessage].by(uint16) + .typecase(18, pingCodec) val pipe = system.actorOf(Props[MyPipePull]()) val probe1 = TestProbe() val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, codec1)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, codec12)) + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, incompleteCodec)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, lightningMessageCodec)) pipe ! (initiator, responder) awaitCond(initiator.stateName == TransportHandler.WaitingForListener) @@ -135,16 +95,18 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be awaitCond(initiator.stateName == TransportHandler.Normal) awaitCond(responder.stateName == TransportHandler.Normal) - responder ! Msg1 - probe1.expectMsg(Msg1) - probe1.reply(TransportHandler.ReadAck(Msg1)) + val msg1 = Ping(130, hex"deadbeef") + responder ! 
msg1 + probe1.expectMsg(msg1) + probe1.reply(TransportHandler.ReadAck(msg1)) - responder ! Msg2 + responder ! Pong(hex"deadbeef") probe1.expectNoMessage(2 seconds) // unknown message - responder ! Msg1 - probe1.expectMsg(Msg1) - probe1.reply(TransportHandler.ReadAck(Msg1)) + val msg2 = Ping(42, hex"beefdead") + responder ! msg2 + probe1.expectMsg(msg2) + probe1.reply(TransportHandler.ReadAck(msg2)) probe1.watch(pipe) initiator.stop() @@ -157,8 +119,8 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be val pipe = system.actorOf(Props[MyPipeSplitter]()) val probe1 = TestProbe() val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, CommonCodecs.varsizebinarydata)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata)) + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, lightningMessageCodec)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, lightningMessageCodec)) pipe ! (initiator, responder) awaitCond(initiator.stateName == TransportHandler.WaitingForListener) @@ -170,11 +132,11 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be awaitCond(initiator.stateName == TransportHandler.Normal) awaitCond(responder.stateName == TransportHandler.Normal) - initiator.tell(ByteVector("hello".getBytes), probe1.ref) - probe2.expectMsg(ByteVector("hello".getBytes)) + initiator.tell(Ping(187, ByteVector("hello".getBytes)), probe1.ref) + probe2.expectMsg(Ping(187, ByteVector("hello".getBytes))) - responder.tell(ByteVector("bonjour".getBytes), probe2.ref) - probe1.expectMsg(ByteVector("bonjour".getBytes)) + responder.tell(Pong(ByteVector("bonjour".getBytes)), probe2.ref) + probe1.expectMsg(Pong(ByteVector("bonjour".getBytes))) probe1.watch(pipe) initiator.stop() @@ -187,11 +149,11 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be val pipe = system.actorOf(Props[MyPipe]()) val probe1 = TestProbe() val supervisor = TestActorRef(Props(new MySupervisor())) - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Initiator.s.pub), pipe, CommonCodecs.varsizebinarydata), supervisor, "ini") - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata), supervisor, "res") + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Initiator.s.pub), pipe, lightningMessageCodec), supervisor, "ini") + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, lightningMessageCodec), supervisor, "res") probe1.watch(responder) pipe ! (initiator, responder) - + // We automatically disconnect after a while if the handshake doesn't succeed. 
probe1.expectTerminated(responder, 3 seconds) } @@ -219,9 +181,7 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be */ val ck = hex"919219dbb2920afa8db80f9a51787a840bcf111ed8d588caf9ab4be716e42b01" val sk = hex"969ab31b4d288cedf6218839b27a3e2140827047f2c0f01bf5c04435d43511a9" - val rk = hex"bb9020b8965f4df047e07f955f3c4b88418984aadc5cdb35096b9ea8fa5c3442" val enc = ExtendedCipherState(CipherState(sk, Chacha20Poly1305CipherFunctions), ck) - val dec = ExtendedCipherState(CipherState(rk, Chacha20Poly1305CipherFunctions), ck) @tailrec def loop(cs: Encryptor, count: Int, acc: Vector[ByteVector] = Vector.empty[ByteVector]): Vector[ByteVector] = { @@ -244,15 +204,13 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be object TransportHandlerSpec { class MyPipe extends Actor with Stash with ActorLogging { - - def receive = { + def receive: Receive = { case (a: ActorRef, b: ActorRef) => unstashAll() context watch a context watch b context become ready(a, b) - - case msg => stash() + case _ => stash() } def ready(a: ActorRef, b: ActorRef): Receive = { @@ -267,15 +225,13 @@ object TransportHandlerSpec { } class MyPipeSplitter extends Actor with Stash { - - def receive = { + def receive: Receive = { case (a: ActorRef, b: ActorRef) => unstashAll() context watch a context watch b context become ready(a, b) - - case msg => stash() + case _ => stash() } def ready(a: ActorRef, b: ActorRef): Receive = { @@ -296,15 +252,13 @@ object TransportHandlerSpec { } class MyPipePull extends Actor with Stash { - - def receive = { + def receive: Receive = { case (a: ActorRef, b: ActorRef) => unstashAll() context watch a context watch b context become ready(a, b, aResume = true, bResume = true) - - case msg => stash() + case _ => stash() } def ready(a: ActorRef, b: ActorRef, aResume: Boolean, bResume: Boolean): Receive = { @@ -336,7 +290,7 @@ object TransportHandlerSpec { case _ => SupervisorStrategy.stop } - def receive = { + def receive: Receive = { case _ => () } } From 5dbd63a5d0f7d69399fbf948d5ccffbdf93b3e2d Mon Sep 17 00:00:00 2001 From: t-bast Date: Tue, 13 May 2025 16:29:53 +0200 Subject: [PATCH 2/3] Send splice `commit_sig` as a batch We introduce a `CommitSigBatch` class to group `commit_sig` messages when splice transactions are pending. We use this class to ensure that all the `commit_sig` messages in the batch are sent together to our peer, without any other messages in-between. 
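
For illustration only (not part of the diff below), the `CommitSigs` companion collapses a singleton sequence into a plain `CommitSig` and otherwise wraps the messages in a `CommitSigBatch`, so callers pattern-match on the result. This is a rough sketch; the dummy signatures built with `randomBytes32`/`randomBytes64` exist only for the example:

    import fr.acinq.eclair.wire.protocol.{CommitSig, CommitSigBatch, CommitSigs}
    import fr.acinq.eclair.{randomBytes32, randomBytes64}

    object CommitSigsExample extends App {
      val channelId = randomBytes32()
      // Two dummy commit_sig messages for the same channel, as if two commitments were active.
      val sigs = Seq(CommitSig(channelId, randomBytes64(), Nil), CommitSig(channelId, randomBytes64(), Nil))
      CommitSigs(sigs) match {
        case sig: CommitSig => println(s"single active commitment: standalone commit_sig for ${sig.channelId}")
        case batch: CommitSigBatch => println(s"${batch.batchSize} active commitments: commit_sig messages sent back-to-back")
      }
    }
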
--- .../fr/acinq/eclair/channel/Commitments.scala | 15 +++-- .../fr/acinq/eclair/channel/Helpers.scala | 8 +-- .../fr/acinq/eclair/channel/fsm/Channel.scala | 4 +- .../fr/acinq/eclair/io/PeerConnection.scala | 12 ++-- .../wire/protocol/LightningMessageTypes.scala | 18 +++++- .../ChannelStateTestsHelperMethods.scala | 34 +++++------ .../states/e/NormalSplicesStateSpec.scala | 58 ++++++++----------- .../acinq/eclair/io/PeerConnectionSpec.scala | 17 +++++- 8 files changed, 93 insertions(+), 73 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala index 721612c515..b328c07ad9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala @@ -993,7 +993,7 @@ case class Commitments(params: ChannelParams, } } - def sendCommit(channelKeys: ChannelKeys)(implicit log: LoggingAdapter): Either[ChannelException, (Commitments, Seq[CommitSig])] = { + def sendCommit(channelKeys: ChannelKeys)(implicit log: LoggingAdapter): Either[ChannelException, (Commitments, CommitSigs)] = { remoteNextCommitInfo match { case Right(_) if !changes.localHasChanges => Left(CannotSignWithoutChanges(channelId)) case Right(remoteNextPerCommitmentPoint) => @@ -1007,21 +1007,24 @@ case class Commitments(params: ChannelParams, active = active1, remoteNextCommitInfo = Left(WaitForRev(localCommitIndex)) ) - Right(commitments1, sigs) + Right(commitments1, CommitSigs(sigs)) case Left(_) => Left(CannotSignBeforeRevocation(channelId)) } } - def receiveCommit(commits: Seq[CommitSig], channelKeys: ChannelKeys)(implicit log: LoggingAdapter): Either[ChannelException, (Commitments, RevokeAndAck)] = { + def receiveCommit(commitSigs: CommitSigs, channelKeys: ChannelKeys)(implicit log: LoggingAdapter): Either[ChannelException, (Commitments, RevokeAndAck)] = { // We may receive more commit_sig than the number of active commitments, because there can be a race where we send // splice_locked while our peer is sending us a batch of commit_sig. When that happens, we simply need to discard // the commit_sig that belong to commitments we deactivated. - if (commits.size < active.size) { - return Left(CommitSigCountMismatch(channelId, active.size, commits.size)) + val sigs = commitSigs match { + case batch: CommitSigBatch if batch.batchSize < active.size => return Left(CommitSigCountMismatch(channelId, active.size, batch.batchSize)) + case batch: CommitSigBatch => batch.messages + case _: CommitSig if active.size > 1 => return Left(CommitSigCountMismatch(channelId, active.size, 1)) + case commitSig: CommitSig => Seq(commitSig) } val commitKeys = LocalCommitmentKeys(params, channelKeys, localCommitIndex + 1) // Signatures are sent in order (most recent first), calling `zip` will drop trailing sigs that are for deactivated/pruned commitments. 
- val active1 = active.zip(commits).map { case (commitment, commit) => + val active1 = active.zip(sigs).map { case (commitment, commit) => commitment.receiveCommit(params, channelKeys, commitKeys, changes, commit) match { case Left(f) => return Left(f) case Right(commitment1) => commitment1 diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala index 65afc6e37f..e7048a0d7f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala @@ -478,14 +478,14 @@ object Helpers { // we just sent a new commit_sig but they didn't receive it // we resend the same updates and the same sig, and preserve the same ordering val signedUpdates = commitments.changes.localChanges.signed - val commitSigs = commitments.active.flatMap(_.nextRemoteCommit_opt).map(_.sig) + val commitSigs = CommitSigs(commitments.active.flatMap(_.nextRemoteCommit_opt).map(_.sig)) retransmitRevocation_opt match { case None => - SyncResult.Success(retransmit = signedUpdates ++ commitSigs) + SyncResult.Success(retransmit = signedUpdates :+ commitSigs) case Some(revocation) if commitments.localCommitIndex > waitingForRevocation.sentAfterLocalCommitIndex => - SyncResult.Success(retransmit = signedUpdates ++ commitSigs ++ Seq(revocation)) + SyncResult.Success(retransmit = signedUpdates :+ commitSigs :+ revocation) case Some(revocation) => - SyncResult.Success(retransmit = Seq(revocation) ++ signedUpdates ++ commitSigs) + SyncResult.Success(retransmit = revocation +: signedUpdates :+ commitSigs) } case Left(_) if remoteChannelReestablish.nextLocalCommitmentNumber == (commitments.nextRemoteCommitIndex + 1) => // we just sent a new commit_sig, they have received it but we haven't received their revocation diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala index 44241283ba..470ef7e2ca 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala @@ -3038,11 +3038,11 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall */ /** For splices we will send one commit_sig per active commitments. 
*/ - private def aggregateSigs(commit: CommitSig): Option[Seq[CommitSig]] = { + private def aggregateSigs(commit: CommitSig): Option[CommitSigs] = { sigStash = sigStash :+ commit log.debug("received sig for batch of size={}", commit.batchSize) if (sigStash.size == commit.batchSize) { - val sigs = sigStash + val sigs = CommitSigs(sigStash) sigStash = Nil Some(sigs) } else { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala index ffcf92cac7..4ecc3983dc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.io import akka.actor.{ActorRef, FSM, OneForOneStrategy, PoisonPill, Props, Stash, SupervisorStrategy, Terminated} import akka.event.Logging.MDC -import fr.acinq.bitcoin.scalacompat.{BlockHash, ByteVector32} +import fr.acinq.bitcoin.scalacompat.BlockHash import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.crypto.Noise.KeyPair @@ -28,7 +28,7 @@ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{FSMDiagnosticActorLogging, FeatureCompatibilityResult, Features, InitFeature, Logs, TimestampMilli, TimestampSecond} +import fr.acinq.eclair.{FSMDiagnosticActorLogging, Features, InitFeature, Logs, TimestampMilli, TimestampSecond} import scodec.Attempt import scodec.bits.ByteVector @@ -206,11 +206,13 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A stay() case Event(msg: LightningMessage, d: ConnectedData) if sender() != d.transport => // if the message doesn't originate from the transport, it is an outgoing message - d.transport forward msg + msg match { + case batch: CommitSigBatch => batch.messages.foreach(msg => d.transport forward msg) + case msg => d.transport forward msg + } msg match { // If we send any channel management message to this peer, the connection should be persistent. - case _: ChannelMessage if !d.isPersistent => - stay() using d.copy(isPersistent = true) + case _: ChannelMessage if !d.isPersistent => stay() using d.copy(isPersistent = true) case _ => stay() } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala index 11034b42d0..331dfc87c7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala @@ -427,13 +427,29 @@ case class UpdateFailMalformedHtlc(channelId: ByteVector32, failureCode: Int, tlvStream: TlvStream[UpdateFailMalformedHtlcTlv] = TlvStream.empty) extends HtlcMessage with UpdateMessage with HasChannelId with HtlcFailureMessage +/** + * [[CommitSig]] can either be sent individually or as part of a batch. When sent in a batch (which happens when there + * are pending splice transactions), we treat the whole batch as a single lightning message and group them on the wire. 
+ */ +sealed trait CommitSigs extends HtlcMessage with HasChannelId + +object CommitSigs { + def apply(sigs: Seq[CommitSig]): CommitSigs = if (sigs.size == 1) sigs.head else CommitSigBatch(sigs) +} + case class CommitSig(channelId: ByteVector32, signature: ByteVector64, htlcSignatures: List[ByteVector64], - tlvStream: TlvStream[CommitSigTlv] = TlvStream.empty) extends HtlcMessage with HasChannelId { + tlvStream: TlvStream[CommitSigTlv] = TlvStream.empty) extends CommitSigs { val batchSize: Int = tlvStream.get[CommitSigTlv.BatchTlv].map(_.size).getOrElse(1) } +case class CommitSigBatch(messages: Seq[CommitSig]) extends CommitSigs { + require(messages.map(_.channelId).toSet.size == 1, "commit_sig messages in a batch must be for the same channel") + val channelId: ByteVector32 = messages.head.channelId + val batchSize: Int = messages.size +} + case class RevokeAndAck(channelId: ByteVector32, perCommitmentSecret: PrivateKey, nextPerCommitmentPoint: PublicKey, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala index 48ebba08a2..b7f8e23f9f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher._ import fr.acinq.eclair.blockchain.fee.{FeeratePerKw, FeeratesPerKw} -import fr.acinq.eclair.blockchain.{DummyOnChainWallet, OnChainWallet, OnChainPubkeyCache, SingleKeyOnChainWallet} +import fr.acinq.eclair.blockchain.{DummyOnChainWallet, OnChainPubkeyCache, OnChainWallet, SingleKeyOnChainWallet} import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.channel.publish.TxPublisher @@ -463,31 +463,23 @@ trait ChannelStateTestsBase extends Assertions with Eventually { val rHasChanges = r.stateData.asInstanceOf[ChannelDataWithCommitments].commitments.changes.localHasChanges s ! 
CMD_SIGN(Some(sender.ref)) sender.expectMsgType[RES_SUCCESS[CMD_SIGN]] - var sigs2r = 0 - var batchSize = 0 - do { - val sig = s2r.expectMsgType[CommitSig] - s2r.forward(r) - sigs2r += 1 - batchSize = sig.batchSize - } while (sigs2r < batchSize) + s2r.expectMsgType[CommitSigs] match { + case sig: CommitSig => s2r.forward(r, sig) + case batch: CommitSigBatch => batch.messages.foreach(sig => s2r.forward(r, sig)) + } r2s.expectMsgType[RevokeAndAck] r2s.forward(s) - var sigr2s = 0 - do { - r2s.expectMsgType[CommitSig] - r2s.forward(s) - sigr2s += 1 - } while (sigr2s < batchSize) + r2s.expectMsgType[CommitSigs] match { + case sig: CommitSig => r2s.forward(s, sig) + case batch: CommitSigBatch => batch.messages.foreach(sig => r2s.forward(s, sig)) + } s2r.expectMsgType[RevokeAndAck] s2r.forward(r) if (rHasChanges) { - sigs2r = 0 - do { - s2r.expectMsgType[CommitSig] - s2r.forward(r) - sigs2r += 1 - } while (sigs2r < batchSize) + s2r.expectMsgType[CommitSigs] match { + case sig: CommitSig => s2r.forward(r, sig) + case batch: CommitSigBatch => batch.messages.foreach(sig => s2r.forward(r, sig)) + } r2s.expectMsgType[RevokeAndAck] r2s.forward(s) eventually { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala index c51e43176b..e77b87d36e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala @@ -1517,20 +1517,14 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.expectMsgType[UpdateAddHtlc] alice2bob.forward(bob) alice ! CMD_SIGN() - val sigA1 = alice2bob.expectMsgType[CommitSig] - assert(sigA1.batchSize == 2) - alice2bob.forward(bob) - val sigA2 = alice2bob.expectMsgType[CommitSig] - assert(sigA2.batchSize == 2) - alice2bob.forward(bob) + val sigsA = alice2bob.expectMsgType[CommitSigBatch] + assert(sigsA.batchSize == 2) + sigsA.messages.foreach(sig => alice2bob.forward(bob, sig)) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) - val sigB1 = bob2alice.expectMsgType[CommitSig] - assert(sigB1.batchSize == 2) - bob2alice.forward(alice) - val sigB2 = bob2alice.expectMsgType[CommitSig] - assert(sigB2.batchSize == 2) - bob2alice.forward(alice) + val sigsB = bob2alice.expectMsgType[CommitSigBatch] + assert(sigsB.batchSize == 2) + sigsB.messages.foreach(sig => bob2alice.forward(alice, sig)) alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1)) @@ -1546,23 +1540,20 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.expectMsgType[UpdateAddHtlc] alice2bob.forward(bob) alice ! CMD_SIGN() - assert(alice2bob.expectMsgType[CommitSig].batchSize == 2) - assert(alice2bob.expectMsgType[CommitSig].batchSize == 2) + assert(alice2bob.expectMsgType[CommitSigBatch].batchSize == 2) // Bob disconnects before receiving Alice's commit_sig. 
disconnect(f) reconnect(f) alice2bob.expectMsgType[UpdateAddHtlc] alice2bob.forward(bob) - assert(alice2bob.expectMsgType[CommitSig].batchSize == 2) - alice2bob.forward(bob) - assert(alice2bob.expectMsgType[CommitSig].batchSize == 2) - alice2bob.forward(bob) + val sigsA = alice2bob.expectMsgType[CommitSigBatch] + assert(sigsA.batchSize == 2) + sigsA.messages.foreach(sig => alice2bob.forward(bob, sig)) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) - assert(bob2alice.expectMsgType[CommitSig].batchSize == 2) - bob2alice.forward(alice) - assert(bob2alice.expectMsgType[CommitSig].batchSize == 2) - bob2alice.forward(alice) + val sigsB = bob2alice.expectMsgType[CommitSigBatch] + assert(sigsB.batchSize == 2) + sigsB.messages.foreach(sig => bob2alice.forward(alice, sig)) alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1)) @@ -1679,14 +1670,15 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.forward(bob, spliceLockedAlice) val (preimage, htlc) = addHtlc(20_000_000 msat, alice, bob, alice2bob, bob2alice) alice ! CMD_SIGN() - val commitSigsAlice = (1 to 3).map(_ => alice2bob.expectMsgType[CommitSig]) - alice2bob.forward(bob, commitSigsAlice(0)) + val commitSigsAlice = alice2bob.expectMsgType[CommitSigBatch] + assert(commitSigsAlice.batchSize == 3) + alice2bob.forward(bob, commitSigsAlice.messages(0)) bob ! WatchPublishedTriggered(spliceTx2) val spliceLockedBob = bob2alice.expectMsgType[SpliceLocked] assert(spliceLockedBob.fundingTxId == spliceTx2.txid) bob2alice.forward(alice, spliceLockedBob) - alice2bob.forward(bob, commitSigsAlice(1)) - alice2bob.forward(bob, commitSigsAlice(2)) + alice2bob.forward(bob, commitSigsAlice.messages(1)) + alice2bob.forward(bob, commitSigsAlice.messages(2)) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) assert(bob2alice.expectMsgType[CommitSig].batchSize == 1) @@ -2858,7 +2850,7 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik // Alice adds another HTLC that isn't signed by Bob. val (_, htlcOut2) = addHtlc(15_000_000 msat, alice, bob, alice2bob, bob2alice) alice ! CMD_SIGN() - alice2bob.expectMsgType[CommitSig] // Bob ignores Alice's message + alice2bob.expectMsgType[CommitSigBatch] // Bob ignores Alice's message // The first splice transaction confirms. alice ! WatchFundingConfirmedTriggered(BlockHeight(400000), 42, fundingTx1) @@ -3341,15 +3333,15 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik // Bob is waiting to sign its outgoing HTLC before sending stfu. bob2alice.expectNoMessage(100 millis) bob ! 
CMD_SIGN() - (0 until 3).foreach { _ => - bob2alice.expectMsgType[CommitSig] - bob2alice.forward(alice) + inside(bob2alice.expectMsgType[CommitSigBatch]) { batch => + assert(batch.batchSize == 3) + batch.messages.foreach(sig => bob2alice.forward(alice, sig)) } alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) - (0 until 3).foreach { _ => - alice2bob.expectMsgType[CommitSig] - alice2bob.forward(bob) + inside(alice2bob.expectMsgType[CommitSigBatch]) { batch => + assert(batch.batchSize == 3) + batch.messages.foreach(sig => alice2bob.forward(bob, sig)) } bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala index 4b9b1a152a..eabae39ee0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala @@ -21,7 +21,7 @@ import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} -import fr.acinq.eclair.Features.{BasicMultiPartPayment, ChannelRangeQueries, ChannelType, PaymentSecret, StaticRemoteKey, VariableLengthOnion} +import fr.acinq.eclair.Features._ import fr.acinq.eclair.TestConstants._ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.ConnectionDown @@ -333,6 +333,21 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi transport.expectNoMessage(10 / transport.testKitSettings.TestTimeFactor seconds) // we don't want dilated time here } + test("send batch of commit_sig messages") { f => + import f._ + val probe = TestProbe() + connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) + val channelId = randomBytes32() + val commitSigs = Seq( + CommitSig(channelId, randomBytes64(), Nil), + CommitSig(channelId, randomBytes64(), Nil), + CommitSig(channelId, randomBytes64(), Nil), + ) + probe.send(peerConnection, CommitSigBatch(commitSigs)) + commitSigs.foreach(commitSig => transport.expectMsg(commitSig)) + transport.expectNoMessage(100 millis) + } + test("react to peer's bad behavior") { f => import f._ val probe = TestProbe() From 71fb4fad35169689a7bcb47d032333cb4da629b4 Mon Sep 17 00:00:00 2001 From: t-bast Date: Thu, 15 May 2025 17:05:39 +0200 Subject: [PATCH 3/3] Batch incoming `commit_sig` in `PeerConnection` We move the incoming `commit_sig` batching logic outside of the channel and into the `PeerConnection` instead. This slightly simplifies the channel FSM and its tests, since the `PeerConnection` actor is simpler. We unfortunately cannot easily do this in the `TransportHandler` because of our buffered read of the encrypted messages, which may split batches and make it more complex to correctly group messages. 
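
For context only, a minimal sketch of the incoming aggregation, assuming the `PeerConnection` reuses the same stash-until-complete logic that this patch removes from the channel's `aggregateSigs`; the names `CommitSigAggregator`, `commitSigStash` and `aggregateIncoming` are illustrative and not the identifiers actually added by the patch:

    import fr.acinq.eclair.wire.protocol.{CommitSig, CommitSigs}

    class CommitSigAggregator {
      // Incoming commit_sig messages stashed until the announced batch is complete.
      private var commitSigStash = Seq.empty[CommitSig]

      /** Returns the full batch once `batchSize` messages have been received, None otherwise. */
      def aggregateIncoming(sig: CommitSig): Option[CommitSigs] = {
        commitSigStash = commitSigStash :+ sig
        if (commitSigStash.size == sig.batchSize) {
          val sigs = CommitSigs(commitSigStash)
          commitSigStash = Nil
          Some(sigs) // forward to the channel as a single message
        } else {
          None // wait for the rest of the batch
        }
      }
    }
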
--- .../fr/acinq/eclair/channel/Commitments.scala | 2 +- .../fr/acinq/eclair/channel/fsm/Channel.scala | 204 ++++++++---------- .../fr/acinq/eclair/io/PeerConnection.scala | 61 +++++- .../wire/protocol/LightningMessageTypes.scala | 4 +- .../ChannelStateTestsHelperMethods.scala | 18 +- .../states/e/NormalSplicesStateSpec.scala | 18 +- .../acinq/eclair/io/PeerConnectionSpec.scala | 80 +++++++ 7 files changed, 242 insertions(+), 145 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala index b328c07ad9..e2f4a89d1a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala @@ -1159,7 +1159,7 @@ case class Commitments(params: ChannelParams, case ChannelSpendSignature.IndividualSignature(latestRemoteSig) => latestRemoteSig == commitSig.signature case ChannelSpendSignature.PartialSignatureWithNonce(_, _) => ??? } - params.channelFeatures.hasFeature(Features.DualFunding) && commitSig.batchSize == 1 && isLatestSig + params.channelFeatures.hasFeature(Features.DualFunding) && isLatestSig } def localFundingSigs(fundingTxId: TxId): Option[TxSignatures] = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala index 470ef7e2ca..ba5bfeb664 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala @@ -211,8 +211,6 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall // choose to not make this an Option (that would be None before the first connection), and instead embrace the fact // that the active connection may point to dead letters at all time var activeConnection = context.system.deadLetters - // we aggregate sigs for splices before processing - var sigStash = Seq.empty[CommitSig] // we stash announcement_signatures if we receive them earlier than expected var announcementSigsStash = Map.empty[RealShortChannelId, AnnouncementSignatures] // we record the announcement_signatures messages we already sent to avoid unnecessary retransmission @@ -581,77 +579,75 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall stay() } - case Event(commit: CommitSig, d: DATA_NORMAL) => - aggregateSigs(commit) match { - case Some(sigs) => - d.spliceStatus match { - case s: SpliceStatus.SpliceInProgress => - log.debug("received their commit_sig, deferring message") - stay() using d.copy(spliceStatus = s.copy(remoteCommitSig = Some(commit))) - case SpliceStatus.SpliceAborted => - log.warning("received commit_sig after sending tx_abort, they probably sent it before receiving our tx_abort, ignoring...") - stay() - case SpliceStatus.SpliceWaitingForSigs(signingSession) => - signingSession.receiveCommitSig(d.commitments.params, channelKeys, commit, nodeParams.currentBlockHeight) match { - case Left(f) => - rollbackFundingAttempt(signingSession.fundingTx.tx, Nil) - stay() using d.copy(spliceStatus = SpliceStatus.SpliceAborted) sending TxAbort(d.channelId, f.getMessage) - case Right(signingSession1) => signingSession1 match { - case signingSession1: InteractiveTxSigningSession.WaitingForSigs => - // In theory we don't have to store their commit_sig here, as they would re-send it if we disconnect, but - // it is more consistent with the case where we send our 
tx_signatures first. - val d1 = d.copy(spliceStatus = SpliceStatus.SpliceWaitingForSigs(signingSession1)) - stay() using d1 storing() - case signingSession1: InteractiveTxSigningSession.SendingSigs => - // We don't have their tx_sigs, but they have ours, and could publish the funding tx without telling us. - // That's why we move on immediately to the next step, and will update our unsigned funding tx when we - // receive their tx_sigs. - val minDepth_opt = d.commitments.params.minDepth(nodeParams.channelConf.minDepth) - watchFundingConfirmed(signingSession.fundingTx.txId, minDepth_opt, delay_opt = None) - val commitments1 = d.commitments.add(signingSession1.commitment) - val d1 = d.copy(commitments = commitments1, spliceStatus = SpliceStatus.NoSplice) - stay() using d1 storing() sending signingSession1.localSigs calling endQuiescence(d1) - } + case Event(commit: CommitSigs, d: DATA_NORMAL) => + (d.spliceStatus, commit) match { + case (s: SpliceStatus.SpliceInProgress, sig: CommitSig) => + log.debug("received their commit_sig, deferring message") + stay() using d.copy(spliceStatus = s.copy(remoteCommitSig = Some(sig))) + case (SpliceStatus.SpliceAborted, sig: CommitSig) => + log.warning("received commit_sig after sending tx_abort, they probably sent it before receiving our tx_abort, ignoring...") + stay() + case (SpliceStatus.SpliceWaitingForSigs(signingSession), sig: CommitSig) => + signingSession.receiveCommitSig(d.commitments.params, channelKeys, sig, nodeParams.currentBlockHeight) match { + case Left(f) => + rollbackFundingAttempt(signingSession.fundingTx.tx, Nil) + stay() using d.copy(spliceStatus = SpliceStatus.SpliceAborted) sending TxAbort(d.channelId, f.getMessage) + case Right(signingSession1) => signingSession1 match { + case signingSession1: InteractiveTxSigningSession.WaitingForSigs => + // In theory we don't have to store their commit_sig here, as they would re-send it if we disconnect, but + // it is more consistent with the case where we send our tx_signatures first. + val d1 = d.copy(spliceStatus = SpliceStatus.SpliceWaitingForSigs(signingSession1)) + stay() using d1 storing() + case signingSession1: InteractiveTxSigningSession.SendingSigs => + // We don't have their tx_sigs, but they have ours, and could publish the funding tx without telling us. + // That's why we move on immediately to the next step, and will update our unsigned funding tx when we + // receive their tx_sigs. + val minDepth_opt = d.commitments.params.minDepth(nodeParams.channelConf.minDepth) + watchFundingConfirmed(signingSession.fundingTx.txId, minDepth_opt, delay_opt = None) + val commitments1 = d.commitments.add(signingSession1.commitment) + val d1 = d.copy(commitments = commitments1, spliceStatus = SpliceStatus.NoSplice) + stay() using d1 storing() sending signingSession1.localSigs calling endQuiescence(d1) + } + } + case (_, sig: CommitSig) if d.commitments.ignoreRetransmittedCommitSig(sig) => + // If our peer hasn't implemented https://github.com/lightning/bolts/pull/1214, they may retransmit commit_sig + // even though we've already received it and haven't requested a retransmission. It is safe to simply ignore + // this commit_sig while we wait for peers to correctly implemented commit_sig retransmission, at which point + // we should be able to get rid of this edge case. + // Note that the funding transaction may have confirmed while we were reconnecting. 
+ log.info("ignoring commit_sig, we're still waiting for tx_signatures") + stay() + case _ => + // NB: in all other cases we process the commit_sigs normally. We could do a full pattern matching on all + // splice statuses, but it would force us to handle every corner case where our peer doesn't behave correctly + // whereas they will all simply lead to a force-close. + d.commitments.receiveCommit(commit, channelKeys) match { + case Right((commitments1, revocation)) => + log.debug("received a new sig, spec:\n{}", commitments1.latest.specs2String) + if (commitments1.changes.localHasChanges) { + // if we have newly acknowledged changes let's sign them + self ! CMD_SIGN() } - case _ if d.commitments.ignoreRetransmittedCommitSig(commit) => - // We haven't received our peer's tx_signatures for the latest funding transaction and asked them to resend it on reconnection. - // They also resend their corresponding commit_sig, but we have already received it so we should ignore it. - // Note that the funding transaction may have confirmed while we were reconnecting. - log.info("ignoring commit_sig, we're still waiting for tx_signatures") - stay() - case _ => - // NB: in all other cases we process the commit_sig normally. We could do a full pattern matching on all - // splice statuses, but it would force us to handle corner cases like race condition between splice_init - // and a non-splice commit_sig - d.commitments.receiveCommit(sigs, channelKeys) match { - case Right((commitments1, revocation)) => - log.debug("received a new sig, spec:\n{}", commitments1.latest.specs2String) - if (commitments1.changes.localHasChanges) { - // if we have newly acknowledged changes let's sign them - self ! CMD_SIGN() - } - if (d.commitments.availableBalanceForSend != commitments1.availableBalanceForSend) { - // we send this event only when our balance changes - context.system.eventStream.publish(AvailableBalanceChanged(self, d.channelId, d.aliases, commitments1, d.lastAnnouncement_opt)) - } - context.system.eventStream.publish(ChannelSignatureReceived(self, commitments1)) - // If we're now quiescent, we may send our stfu message. - val (d1, toSend) = d.spliceStatus match { - case SpliceStatus.NegotiatingQuiescence(cmd_opt, QuiescenceNegotiation.Initiator.QuiescenceRequested) if commitments1.localIsQuiescent => - val stfu = Stfu(d.channelId, initiator = true) - val spliceStatus1 = SpliceStatus.NegotiatingQuiescence(cmd_opt, QuiescenceNegotiation.Initiator.SentStfu(stfu)) - (d.copy(commitments = commitments1, spliceStatus = spliceStatus1), Seq(revocation, stfu)) - case SpliceStatus.NegotiatingQuiescence(_, _: QuiescenceNegotiation.NonInitiator.ReceivedStfu) if commitments1.localIsQuiescent => - val stfu = Stfu(d.channelId, initiator = false) - (d.copy(commitments = commitments1, spliceStatus = SpliceStatus.NonInitiatorQuiescent), Seq(revocation, stfu)) - case _ => - (d.copy(commitments = commitments1), Seq(revocation)) - } - stay() using d1 storing() sending toSend - case Left(cause) => handleLocalError(cause, d, Some(commit)) + if (d.commitments.availableBalanceForSend != commitments1.availableBalanceForSend) { + // we send this event only when our balance changes + context.system.eventStream.publish(AvailableBalanceChanged(self, d.channelId, d.aliases, commitments1, d.lastAnnouncement_opt)) + } + context.system.eventStream.publish(ChannelSignatureReceived(self, commitments1)) + // If we're now quiescent, we may send our stfu message. 
+ val (d1, toSend) = d.spliceStatus match { + case SpliceStatus.NegotiatingQuiescence(cmd_opt, QuiescenceNegotiation.Initiator.QuiescenceRequested) if commitments1.localIsQuiescent => + val stfu = Stfu(d.channelId, initiator = true) + val spliceStatus1 = SpliceStatus.NegotiatingQuiescence(cmd_opt, QuiescenceNegotiation.Initiator.SentStfu(stfu)) + (d.copy(commitments = commitments1, spliceStatus = spliceStatus1), Seq(revocation, stfu)) + case SpliceStatus.NegotiatingQuiescence(_, _: QuiescenceNegotiation.NonInitiator.ReceivedStfu) if commitments1.localIsQuiescent => + val stfu = Stfu(d.channelId, initiator = false) + (d.copy(commitments = commitments1, spliceStatus = SpliceStatus.NonInitiatorQuiescent), Seq(revocation, stfu)) + case _ => + (d.copy(commitments = commitments1), Seq(revocation)) } + stay() using d1 storing() sending toSend + case Left(cause) => handleLocalError(cause, d, Some(commit)) } - case None => stay() } case Event(revocation: RevokeAndAck, d: DATA_NORMAL) => @@ -1574,36 +1570,32 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall stay() } - case Event(commit: CommitSig, d@DATA_SHUTDOWN(_, localShutdown, remoteShutdown, closeStatus)) => - aggregateSigs(commit) match { - case Some(sigs) => - d.commitments.receiveCommit(sigs, channelKeys) match { - case Right((commitments1, revocation)) => - // we always reply with a revocation - log.debug("received a new sig:\n{}", commitments1.latest.specs2String) - context.system.eventStream.publish(ChannelSignatureReceived(self, commitments1)) - if (commitments1.hasNoPendingHtlcsOrFeeUpdate) { - if (Features.canUseFeature(d.commitments.params.localParams.initFeatures, d.commitments.params.remoteParams.initFeatures, Features.SimpleClose)) { - val (d1, closingComplete_opt) = startSimpleClose(d.commitments, localShutdown, remoteShutdown, closeStatus) - goto(NEGOTIATING_SIMPLE) using d1 storing() sending revocation +: closingComplete_opt.toSeq - } else if (d.commitments.params.localParams.paysClosingFees) { - // we pay the closing fees, so we initiate the negotiation by sending the first closing_signed - val (closingTx, closingSigned) = MutualClose.makeFirstClosingTx(channelKeys, commitments1.latest, localShutdown.scriptPubKey, remoteShutdown.scriptPubKey, nodeParams.currentFeeratesForFundingClosing, nodeParams.onChainFeeConf, d.closeStatus.feerates_opt) - goto(NEGOTIATING) using DATA_NEGOTIATING(commitments1, localShutdown, remoteShutdown, List(List(ClosingTxProposed(closingTx, closingSigned))), bestUnpublishedClosingTx_opt = None) storing() sending revocation :: closingSigned :: Nil - } else { - // we are not the channel initiator, will wait for their closing_signed - goto(NEGOTIATING) using DATA_NEGOTIATING(commitments1, localShutdown, remoteShutdown, closingTxProposed = List(List()), bestUnpublishedClosingTx_opt = None) storing() sending revocation - } - } else { - if (commitments1.changes.localHasChanges) { - // if we have newly acknowledged changes let's sign them - self ! 
CMD_SIGN() - } - stay() using d.copy(commitments = commitments1) storing() sending revocation - } - case Left(cause) => handleLocalError(cause, d, Some(commit)) + case Event(commit: CommitSigs, d@DATA_SHUTDOWN(_, localShutdown, remoteShutdown, closeStatus)) => + d.commitments.receiveCommit(commit, channelKeys) match { + case Right((commitments1, revocation)) => + // we always reply with a revocation + log.debug("received a new sig:\n{}", commitments1.latest.specs2String) + context.system.eventStream.publish(ChannelSignatureReceived(self, commitments1)) + if (commitments1.hasNoPendingHtlcsOrFeeUpdate) { + if (Features.canUseFeature(d.commitments.params.localParams.initFeatures, d.commitments.params.remoteParams.initFeatures, Features.SimpleClose)) { + val (d1, closingComplete_opt) = startSimpleClose(d.commitments, localShutdown, remoteShutdown, closeStatus) + goto(NEGOTIATING_SIMPLE) using d1 storing() sending revocation +: closingComplete_opt.toSeq + } else if (d.commitments.params.localParams.paysClosingFees) { + // we pay the closing fees, so we initiate the negotiation by sending the first closing_signed + val (closingTx, closingSigned) = MutualClose.makeFirstClosingTx(channelKeys, commitments1.latest, localShutdown.scriptPubKey, remoteShutdown.scriptPubKey, nodeParams.currentFeeratesForFundingClosing, nodeParams.onChainFeeConf, d.closeStatus.feerates_opt) + goto(NEGOTIATING) using DATA_NEGOTIATING(commitments1, localShutdown, remoteShutdown, List(List(ClosingTxProposed(closingTx, closingSigned))), bestUnpublishedClosingTx_opt = None) storing() sending revocation :: closingSigned :: Nil + } else { + // we are not the channel initiator, will wait for their closing_signed + goto(NEGOTIATING) using DATA_NEGOTIATING(commitments1, localShutdown, remoteShutdown, closingTxProposed = List(List()), bestUnpublishedClosingTx_opt = None) storing() sending revocation + } + } else { + if (commitments1.changes.localHasChanges) { + // if we have newly acknowledged changes let's sign them + self ! CMD_SIGN() + } + stay() using d.copy(commitments = commitments1) storing() sending revocation } - case None => stay() + case Left(cause) => handleLocalError(cause, d, Some(commit)) } case Event(revocation: RevokeAndAck, d@DATA_SHUTDOWN(_, localShutdown, remoteShutdown, closeStatus)) => @@ -3020,7 +3012,6 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall /** On disconnection we clear up stashes. */ onTransition { case _ -> OFFLINE => - sigStash = Nil announcementSigsStash = Map.empty announcementSigsSent = Set.empty spliceLockedSent = Map.empty[TxId, Long] @@ -3037,19 +3028,6 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall 888 888 d88P 888 888 Y888 8888888P" 88888888 8888888888 888 T88b "Y8888P" */ - /** For splices we will send one commit_sig per active commitments. 
*/ - private def aggregateSigs(commit: CommitSig): Option[CommitSigs] = { - sigStash = sigStash :+ commit - log.debug("received sig for batch of size={}", commit.batchSize) - if (sigStash.size == commit.batchSize) { - val sigs = CommitSigs(sigStash) - sigStash = Nil - Some(sigs) - } else { - None - } - } - private def handleCurrentFeerate(c: CurrentFeerates, d: ChannelDataWithCommitments) = { val commitments = d.commitments.latest val networkFeeratePerKw = nodeParams.onChainFeeConf.getCommitmentFeerate(nodeParams.currentBitcoinCoreFeerates, remoteNodeId, d.commitments.params.commitmentFormat, commitments.capacity) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala index 4ecc3983dc..8f2fe0254c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala @@ -18,8 +18,8 @@ package fr.acinq.eclair.io import akka.actor.{ActorRef, FSM, OneForOneStrategy, PoisonPill, Props, Stash, SupervisorStrategy, Terminated} import akka.event.Logging.MDC -import fr.acinq.bitcoin.scalacompat.BlockHash import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.bitcoin.scalacompat.{BlockHash, ByteVector32} import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.crypto.Noise.KeyPair import fr.acinq.eclair.crypto.TransportHandler @@ -343,10 +343,48 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A stay() case Event(msg: LightningMessage, d: ConnectedData) => - // we acknowledge and pass all other messages to the peer + // We immediately acknowledge all other messages. d.transport ! TransportHandler.ReadAck(msg) - d.peer ! msg - stay() + // We immediately forward messages to the peer, unless they are part of a batch, in which case we wait to + // receive the whole batch before forwarding. + msg match { + case msg: CommitSig => + msg.tlvStream.get[CommitSigTlv.BatchTlv].map(_.size) match { + case Some(batchSize) if batchSize > 25 => + log.warning("received legacy batch of commit_sig exceeding our threshold ({} > 25), processing messages individually", batchSize) + // We don't want peers to be able to exhaust our memory by sending batches of dummy messages that we keep in RAM. + d.peer ! msg + stay() + case Some(batchSize) if batchSize > 1 => + d.legacyCommitSigBatch_opt match { + case Some(pending) if pending.channelId != msg.channelId || pending.batchSize != batchSize => + log.warning("received invalid commit_sig batch while a different batch isn't complete") + // This should never happen, otherwise it will likely lead to a force-close. + d.peer ! CommitSigBatch(pending.received) + stay() using d.copy(legacyCommitSigBatch_opt = Some(PendingCommitSigBatch(msg.channelId, batchSize, Seq(msg)))) + case Some(pending) => + val received1 = pending.received :+ msg + if (received1.size == batchSize) { + log.debug("received last commit_sig in legacy batch for channel_id={}", msg.channelId) + d.peer ! 
CommitSigBatch(received1) + stay() using d.copy(legacyCommitSigBatch_opt = None) + } else { + log.debug("received commit_sig {}/{} in legacy batch for channel_id={}", received1.size, batchSize, msg.channelId) + stay() using d.copy(legacyCommitSigBatch_opt = Some(pending.copy(received = received1))) + } + case None => + log.debug("received first commit_sig in legacy batch of size {} for channel_id={}", batchSize, msg.channelId) + stay() using d.copy(legacyCommitSigBatch_opt = Some(PendingCommitSigBatch(msg.channelId, batchSize, Seq(msg)))) + } + case _ => + log.debug("received individual commit_sig for channel_id={}", msg.channelId) + d.peer ! msg + stay() + } + case _ => + d.peer ! msg + stay() + } case Event(readAck: TransportHandler.ReadAck, d: ConnectedData) => // we just forward acks to the transport (e.g. from the router) @@ -566,8 +604,19 @@ object PeerConnection { case class AuthenticatingData(pendingAuth: PendingAuth, transport: ActorRef, isPersistent: Boolean) extends Data with HasTransport case class BeforeInitData(remoteNodeId: PublicKey, pendingAuth: PendingAuth, transport: ActorRef, isPersistent: Boolean) extends Data with HasTransport case class InitializingData(chainHash: BlockHash, pendingAuth: PendingAuth, remoteNodeId: PublicKey, transport: ActorRef, peer: ActorRef, localInit: protocol.Init, doSync: Boolean, isPersistent: Boolean) extends Data with HasTransport - case class ConnectedData(chainHash: BlockHash, remoteNodeId: PublicKey, transport: ActorRef, peer: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, rebroadcastDelay: FiniteDuration, gossipTimestampFilter: Option[GossipTimestampFilter] = None, behavior: Behavior = Behavior(), expectedPong_opt: Option[ExpectedPong] = None, isPersistent: Boolean) extends Data with HasTransport - + case class ConnectedData(chainHash: BlockHash, + remoteNodeId: PublicKey, + transport: ActorRef, + peer: ActorRef, + localInit: protocol.Init, remoteInit: protocol.Init, + rebroadcastDelay: FiniteDuration, + gossipTimestampFilter: Option[GossipTimestampFilter] = None, + behavior: Behavior = Behavior(), + expectedPong_opt: Option[ExpectedPong] = None, + legacyCommitSigBatch_opt: Option[PendingCommitSigBatch] = None, + isPersistent: Boolean) extends Data with HasTransport + + case class PendingCommitSigBatch(channelId: ByteVector32, batchSize: Int, received: Seq[CommitSig]) case class ExpectedPong(ping: Ping, timestamp: TimestampMilli = TimestampMilli.now()) case class PingTimeout(ping: Ping) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala index 331dfc87c7..2c6a95b43b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala @@ -440,9 +440,7 @@ object CommitSigs { case class CommitSig(channelId: ByteVector32, signature: ByteVector64, htlcSignatures: List[ByteVector64], - tlvStream: TlvStream[CommitSigTlv] = TlvStream.empty) extends CommitSigs { - val batchSize: Int = tlvStream.get[CommitSigTlv.BatchTlv].map(_.size).getOrElse(1) -} + tlvStream: TlvStream[CommitSigTlv] = TlvStream.empty) extends CommitSigs case class CommitSigBatch(messages: Seq[CommitSig]) extends CommitSigs { require(messages.map(_.channelId).toSet.size == 1, "commit_sig messages in a batch must be for the same channel") diff --git 
a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala index b7f8e23f9f..3857770251 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala @@ -463,23 +463,17 @@ trait ChannelStateTestsBase extends Assertions with Eventually { val rHasChanges = r.stateData.asInstanceOf[ChannelDataWithCommitments].commitments.changes.localHasChanges s ! CMD_SIGN(Some(sender.ref)) sender.expectMsgType[RES_SUCCESS[CMD_SIGN]] - s2r.expectMsgType[CommitSigs] match { - case sig: CommitSig => s2r.forward(r, sig) - case batch: CommitSigBatch => batch.messages.foreach(sig => s2r.forward(r, sig)) - } + s2r.expectMsgType[CommitSigs] + s2r.forward(r) r2s.expectMsgType[RevokeAndAck] r2s.forward(s) - r2s.expectMsgType[CommitSigs] match { - case sig: CommitSig => r2s.forward(s, sig) - case batch: CommitSigBatch => batch.messages.foreach(sig => r2s.forward(s, sig)) - } + r2s.expectMsgType[CommitSigs] + r2s.forward(s) s2r.expectMsgType[RevokeAndAck] s2r.forward(r) if (rHasChanges) { - s2r.expectMsgType[CommitSigs] match { - case sig: CommitSig => s2r.forward(r, sig) - case batch: CommitSigBatch => batch.messages.foreach(sig => s2r.forward(r, sig)) - } + s2r.expectMsgType[CommitSigs] + s2r.forward(r) r2s.expectMsgType[RevokeAndAck] r2s.forward(s) eventually { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala index e77b87d36e..f77d84c5d6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala @@ -1519,12 +1519,12 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice ! CMD_SIGN() val sigsA = alice2bob.expectMsgType[CommitSigBatch] assert(sigsA.batchSize == 2) - sigsA.messages.foreach(sig => alice2bob.forward(bob, sig)) + alice2bob.forward(bob, sigsA) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) val sigsB = bob2alice.expectMsgType[CommitSigBatch] assert(sigsB.batchSize == 2) - sigsB.messages.foreach(sig => bob2alice.forward(alice, sig)) + bob2alice.forward(alice, sigsB) alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1)) @@ -1548,12 +1548,12 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.forward(bob) val sigsA = alice2bob.expectMsgType[CommitSigBatch] assert(sigsA.batchSize == 2) - sigsA.messages.foreach(sig => alice2bob.forward(bob, sig)) + alice2bob.forward(bob, sigsA) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) val sigsB = bob2alice.expectMsgType[CommitSigBatch] assert(sigsB.batchSize == 2) - sigsB.messages.foreach(sig => bob2alice.forward(alice, sig)) + bob2alice.forward(alice, sigsB) alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1)) @@ -1672,16 +1672,14 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice ! 
CMD_SIGN() val commitSigsAlice = alice2bob.expectMsgType[CommitSigBatch] assert(commitSigsAlice.batchSize == 3) - alice2bob.forward(bob, commitSigsAlice.messages(0)) bob ! WatchPublishedTriggered(spliceTx2) val spliceLockedBob = bob2alice.expectMsgType[SpliceLocked] assert(spliceLockedBob.fundingTxId == spliceTx2.txid) bob2alice.forward(alice, spliceLockedBob) - alice2bob.forward(bob, commitSigsAlice.messages(1)) - alice2bob.forward(bob, commitSigsAlice.messages(2)) + alice2bob.forward(bob, commitSigsAlice) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) - assert(bob2alice.expectMsgType[CommitSig].batchSize == 1) + bob2alice.expectMsgType[CommitSig] bob2alice.forward(alice) alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) @@ -3335,13 +3333,13 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik bob ! CMD_SIGN() inside(bob2alice.expectMsgType[CommitSigBatch]) { batch => assert(batch.batchSize == 3) - batch.messages.foreach(sig => bob2alice.forward(alice, sig)) + bob2alice.forward(alice, batch) } alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) inside(alice2bob.expectMsgType[CommitSigBatch]) { batch => assert(batch.batchSize == 3) - batch.messages.foreach(sig => alice2bob.forward(bob, sig)) + alice2bob.forward(bob, batch) } bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala index eabae39ee0..dcefce23aa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala @@ -23,6 +23,7 @@ import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} import fr.acinq.eclair.Features._ import fr.acinq.eclair.TestConstants._ +import fr.acinq.eclair.TestUtils.randomTxId import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.ConnectionDown import fr.acinq.eclair.message.OnionMessages.{Recipient, buildMessage} @@ -348,6 +349,85 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi transport.expectNoMessage(100 millis) } + test("receive legacy batch of commit_sig messages") { f => + import f._ + connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) + + // We receive a batch of commit_sig messages from a first channel. + val channelId1 = randomBytes32() + val commitSigs1 = Seq( + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + ) + transport.send(peerConnection, commitSigs1.head) + transport.expectMsg(TransportHandler.ReadAck(commitSigs1.head)) + peer.expectNoMessage(100 millis) + transport.send(peerConnection, commitSigs1.last) + transport.expectMsg(TransportHandler.ReadAck(commitSigs1.last)) + peer.expectMsg(CommitSigBatch(commitSigs1)) + + // We receive a batch of commit_sig messages from a second channel. 
+ val channelId2 = randomBytes32() + val commitSigs2 = Seq( + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))), + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))), + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))), + ) + commitSigs2.dropRight(1).foreach(commitSig => { + transport.send(peerConnection, commitSig) + transport.expectMsg(TransportHandler.ReadAck(commitSig)) + }) + peer.expectNoMessage(100 millis) + transport.send(peerConnection, commitSigs2.last) + transport.expectMsg(TransportHandler.ReadAck(commitSigs2.last)) + peer.expectMsg(CommitSigBatch(commitSigs2)) + + // We receive another batch of commit_sig messages from the first channel, with unrelated messages in the batch. + val commitSigs3 = Seq( + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + ) + transport.send(peerConnection, commitSigs3.head) + transport.expectMsg(TransportHandler.ReadAck(commitSigs3.head)) + val spliceLocked1 = SpliceLocked(channelId1, randomTxId()) + transport.send(peerConnection, spliceLocked1) + transport.expectMsg(TransportHandler.ReadAck(spliceLocked1)) + peer.expectMsg(spliceLocked1) + val spliceLocked2 = SpliceLocked(channelId2, randomTxId()) + transport.send(peerConnection, spliceLocked2) + transport.expectMsg(TransportHandler.ReadAck(spliceLocked2)) + peer.expectMsg(spliceLocked2) + peer.expectNoMessage(100 millis) + transport.send(peerConnection, commitSigs3.last) + transport.expectMsg(TransportHandler.ReadAck(commitSigs3.last)) + peer.expectMsg(CommitSigBatch(commitSigs3)) + + // We start receiving a batch of commit_sig messages from the first channel, interleaved with a batch from the second + // channel, which is not supported. + val commitSigs4 = Seq( + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + ) + transport.send(peerConnection, commitSigs4.head) + transport.expectMsg(TransportHandler.ReadAck(commitSigs4.head)) + peer.expectNoMessage(100 millis) + transport.send(peerConnection, commitSigs4(1)) + transport.expectMsg(TransportHandler.ReadAck(commitSigs4(1))) + peer.expectMsg(CommitSigBatch(commitSigs4.take(1))) + transport.send(peerConnection, commitSigs4.last) + transport.expectMsg(TransportHandler.ReadAck(commitSigs4.last)) + peer.expectMsg(CommitSigBatch(commitSigs4.tail)) + + // We receive a batch that exceeds our threshold: we process them individually. + val invalidCommitSigs = (0 until 30).map(_ => CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(30)))) + invalidCommitSigs.foreach(commitSig => { + transport.send(peerConnection, commitSig) + transport.expectMsg(TransportHandler.ReadAck(commitSig)) + peer.expectMsg(commitSig) + }) + } + test("react to peer's bad behavior") { f => import f._ val probe = TestProbe()
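To make the behaviour exercised by this test easier to follow, here is a minimal, self-contained sketch (illustrative only, not eclair code: names such as `LegacyCommitSigBatchSketch`, `Pending` and `receive` are hypothetical, and channel ids are simplified to strings) of the buffering rule that `PeerConnection` now applies to peers that still send a commit_sig batch as individual messages:

// Minimal, self-contained sketch of the aggregation rule shown above: commit_sig messages
// advertising a batch size > 1 are buffered until the whole batch has been received, then
// forwarded as a single batch.
object LegacyCommitSigBatchSketch {

  sealed trait CommitSigs
  case class CommitSig(channelId: String, batchSize: Int) extends CommitSigs
  case class CommitSigBatch(messages: Seq[CommitSig]) extends CommitSigs {
    require(messages.map(_.channelId).toSet.size == 1, "commit_sig messages in a batch must be for the same channel")
  }

  // Mirrors PendingCommitSigBatch: what we have buffered so far for an incomplete legacy batch.
  case class Pending(channelId: String, batchSize: Int, received: Seq[CommitSig])

  // Batches larger than this are forwarded message by message so that a peer cannot make us
  // buffer an unbounded amount of data (the patch uses the same cap of 25).
  val MaxBatchSize = 25

  /** Feed one incoming commit_sig; returns the new pending state and the messages to forward now. */
  def receive(pending: Option[Pending], msg: CommitSig): (Option[Pending], Seq[CommitSigs]) = {
    if (msg.batchSize > MaxBatchSize || msg.batchSize <= 1) {
      // Not part of a batch we're willing to buffer: forward immediately.
      (pending, Seq(msg))
    } else {
      pending match {
        case Some(p) if p.channelId != msg.channelId || p.batchSize != msg.batchSize =>
          // A different batch starts before the previous one is complete: flush what we have
          // (which will likely lead to a force-close, as noted in the patch) and start buffering
          // the new batch.
          (Some(Pending(msg.channelId, msg.batchSize, Seq(msg))), Seq(CommitSigBatch(p.received)))
        case Some(p) =>
          val received1 = p.received :+ msg
          if (received1.size == p.batchSize) (None, Seq(CommitSigBatch(received1)))
          else (Some(p.copy(received = received1)), Nil)
        case None =>
          (Some(Pending(msg.channelId, msg.batchSize, Seq(msg))), Nil)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // A two-message legacy batch: nothing is forwarded until the second commit_sig arrives.
    val (pending1, out1) = receive(None, CommitSig("channel-1", batchSize = 2))
    val (pending2, out2) = receive(pending1, CommitSig("channel-1", batchSize = 2))
    assert(out1.isEmpty && pending1.nonEmpty)
    assert(out2 == Seq(CommitSigBatch(Seq(CommitSig("channel-1", 2), CommitSig("channel-1", 2)))))
    assert(pending2.isEmpty)
  }
}

The real implementation keeps this state in `ConnectedData.legacyCommitSigBatch_opt` and forwards completed batches to the `Peer` actor, as shown in the `PeerConnection` changes above.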