diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala index 721612c515..e2f4a89d1a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala @@ -993,7 +993,7 @@ case class Commitments(params: ChannelParams, } } - def sendCommit(channelKeys: ChannelKeys)(implicit log: LoggingAdapter): Either[ChannelException, (Commitments, Seq[CommitSig])] = { + def sendCommit(channelKeys: ChannelKeys)(implicit log: LoggingAdapter): Either[ChannelException, (Commitments, CommitSigs)] = { remoteNextCommitInfo match { case Right(_) if !changes.localHasChanges => Left(CannotSignWithoutChanges(channelId)) case Right(remoteNextPerCommitmentPoint) => @@ -1007,21 +1007,24 @@ case class Commitments(params: ChannelParams, active = active1, remoteNextCommitInfo = Left(WaitForRev(localCommitIndex)) ) - Right(commitments1, sigs) + Right(commitments1, CommitSigs(sigs)) case Left(_) => Left(CannotSignBeforeRevocation(channelId)) } } - def receiveCommit(commits: Seq[CommitSig], channelKeys: ChannelKeys)(implicit log: LoggingAdapter): Either[ChannelException, (Commitments, RevokeAndAck)] = { + def receiveCommit(commitSigs: CommitSigs, channelKeys: ChannelKeys)(implicit log: LoggingAdapter): Either[ChannelException, (Commitments, RevokeAndAck)] = { // We may receive more commit_sig than the number of active commitments, because there can be a race where we send // splice_locked while our peer is sending us a batch of commit_sig. When that happens, we simply need to discard // the commit_sig that belong to commitments we deactivated. - if (commits.size < active.size) { - return Left(CommitSigCountMismatch(channelId, active.size, commits.size)) + val sigs = commitSigs match { + case batch: CommitSigBatch if batch.batchSize < active.size => return Left(CommitSigCountMismatch(channelId, active.size, batch.batchSize)) + case batch: CommitSigBatch => batch.messages + case _: CommitSig if active.size > 1 => return Left(CommitSigCountMismatch(channelId, active.size, 1)) + case commitSig: CommitSig => Seq(commitSig) } val commitKeys = LocalCommitmentKeys(params, channelKeys, localCommitIndex + 1) // Signatures are sent in order (most recent first), calling `zip` will drop trailing sigs that are for deactivated/pruned commitments. - val active1 = active.zip(commits).map { case (commitment, commit) => + val active1 = active.zip(sigs).map { case (commitment, commit) => commitment.receiveCommit(params, channelKeys, commitKeys, changes, commit) match { case Left(f) => return Left(f) case Right(commitment1) => commitment1 @@ -1156,7 +1159,7 @@ case class Commitments(params: ChannelParams, case ChannelSpendSignature.IndividualSignature(latestRemoteSig) => latestRemoteSig == commitSig.signature case ChannelSpendSignature.PartialSignatureWithNonce(_, _) => ??? } - params.channelFeatures.hasFeature(Features.DualFunding) && commitSig.batchSize == 1 && isLatestSig + params.channelFeatures.hasFeature(Features.DualFunding) && isLatestSig } def localFundingSigs(fundingTxId: TxId): Option[TxSignatures] = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala index 65afc6e37f..e7048a0d7f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala @@ -478,14 +478,14 @@ object Helpers { // we just sent a new commit_sig but they didn't receive it // we resend the same updates and the same sig, and preserve the same ordering val signedUpdates = commitments.changes.localChanges.signed - val commitSigs = commitments.active.flatMap(_.nextRemoteCommit_opt).map(_.sig) + val commitSigs = CommitSigs(commitments.active.flatMap(_.nextRemoteCommit_opt).map(_.sig)) retransmitRevocation_opt match { case None => - SyncResult.Success(retransmit = signedUpdates ++ commitSigs) + SyncResult.Success(retransmit = signedUpdates :+ commitSigs) case Some(revocation) if commitments.localCommitIndex > waitingForRevocation.sentAfterLocalCommitIndex => - SyncResult.Success(retransmit = signedUpdates ++ commitSigs ++ Seq(revocation)) + SyncResult.Success(retransmit = signedUpdates :+ commitSigs :+ revocation) case Some(revocation) => - SyncResult.Success(retransmit = Seq(revocation) ++ signedUpdates ++ commitSigs) + SyncResult.Success(retransmit = revocation +: signedUpdates :+ commitSigs) } case Left(_) if remoteChannelReestablish.nextLocalCommitmentNumber == (commitments.nextRemoteCommitIndex + 1) => // we just sent a new commit_sig, they have received it but we haven't received their revocation diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala index 44241283ba..ba5bfeb664 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala @@ -211,8 +211,6 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall // choose to not make this an Option (that would be None before the first connection), and instead embrace the fact // that the active connection may point to dead letters at all time var activeConnection = context.system.deadLetters - // we aggregate sigs for splices before processing - var sigStash = Seq.empty[CommitSig] // we stash announcement_signatures if we receive them earlier than expected var announcementSigsStash = Map.empty[RealShortChannelId, AnnouncementSignatures] // we record the announcement_signatures messages we already sent to avoid unnecessary retransmission @@ -581,77 +579,75 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall stay() } - case Event(commit: CommitSig, d: DATA_NORMAL) => - aggregateSigs(commit) match { - case Some(sigs) => - d.spliceStatus match { - case s: SpliceStatus.SpliceInProgress => - log.debug("received their commit_sig, deferring message") - stay() using d.copy(spliceStatus = s.copy(remoteCommitSig = Some(commit))) - case SpliceStatus.SpliceAborted => - log.warning("received commit_sig after sending tx_abort, they probably sent it before receiving our tx_abort, ignoring...") - stay() - case SpliceStatus.SpliceWaitingForSigs(signingSession) => - signingSession.receiveCommitSig(d.commitments.params, channelKeys, commit, nodeParams.currentBlockHeight) match { - case Left(f) => - rollbackFundingAttempt(signingSession.fundingTx.tx, Nil) - stay() using d.copy(spliceStatus = SpliceStatus.SpliceAborted) sending TxAbort(d.channelId, f.getMessage) - case Right(signingSession1) => signingSession1 match { - case signingSession1: InteractiveTxSigningSession.WaitingForSigs => - // In theory we don't have to store their commit_sig here, as they would re-send it if we disconnect, but - // it is more consistent with the case where we send our tx_signatures first. - val d1 = d.copy(spliceStatus = SpliceStatus.SpliceWaitingForSigs(signingSession1)) - stay() using d1 storing() - case signingSession1: InteractiveTxSigningSession.SendingSigs => - // We don't have their tx_sigs, but they have ours, and could publish the funding tx without telling us. - // That's why we move on immediately to the next step, and will update our unsigned funding tx when we - // receive their tx_sigs. - val minDepth_opt = d.commitments.params.minDepth(nodeParams.channelConf.minDepth) - watchFundingConfirmed(signingSession.fundingTx.txId, minDepth_opt, delay_opt = None) - val commitments1 = d.commitments.add(signingSession1.commitment) - val d1 = d.copy(commitments = commitments1, spliceStatus = SpliceStatus.NoSplice) - stay() using d1 storing() sending signingSession1.localSigs calling endQuiescence(d1) - } + case Event(commit: CommitSigs, d: DATA_NORMAL) => + (d.spliceStatus, commit) match { + case (s: SpliceStatus.SpliceInProgress, sig: CommitSig) => + log.debug("received their commit_sig, deferring message") + stay() using d.copy(spliceStatus = s.copy(remoteCommitSig = Some(sig))) + case (SpliceStatus.SpliceAborted, sig: CommitSig) => + log.warning("received commit_sig after sending tx_abort, they probably sent it before receiving our tx_abort, ignoring...") + stay() + case (SpliceStatus.SpliceWaitingForSigs(signingSession), sig: CommitSig) => + signingSession.receiveCommitSig(d.commitments.params, channelKeys, sig, nodeParams.currentBlockHeight) match { + case Left(f) => + rollbackFundingAttempt(signingSession.fundingTx.tx, Nil) + stay() using d.copy(spliceStatus = SpliceStatus.SpliceAborted) sending TxAbort(d.channelId, f.getMessage) + case Right(signingSession1) => signingSession1 match { + case signingSession1: InteractiveTxSigningSession.WaitingForSigs => + // In theory we don't have to store their commit_sig here, as they would re-send it if we disconnect, but + // it is more consistent with the case where we send our tx_signatures first. + val d1 = d.copy(spliceStatus = SpliceStatus.SpliceWaitingForSigs(signingSession1)) + stay() using d1 storing() + case signingSession1: InteractiveTxSigningSession.SendingSigs => + // We don't have their tx_sigs, but they have ours, and could publish the funding tx without telling us. + // That's why we move on immediately to the next step, and will update our unsigned funding tx when we + // receive their tx_sigs. + val minDepth_opt = d.commitments.params.minDepth(nodeParams.channelConf.minDepth) + watchFundingConfirmed(signingSession.fundingTx.txId, minDepth_opt, delay_opt = None) + val commitments1 = d.commitments.add(signingSession1.commitment) + val d1 = d.copy(commitments = commitments1, spliceStatus = SpliceStatus.NoSplice) + stay() using d1 storing() sending signingSession1.localSigs calling endQuiescence(d1) + } + } + case (_, sig: CommitSig) if d.commitments.ignoreRetransmittedCommitSig(sig) => + // If our peer hasn't implemented https://github.com/lightning/bolts/pull/1214, they may retransmit commit_sig + // even though we've already received it and haven't requested a retransmission. It is safe to simply ignore + // this commit_sig while we wait for peers to correctly implemented commit_sig retransmission, at which point + // we should be able to get rid of this edge case. + // Note that the funding transaction may have confirmed while we were reconnecting. + log.info("ignoring commit_sig, we're still waiting for tx_signatures") + stay() + case _ => + // NB: in all other cases we process the commit_sigs normally. We could do a full pattern matching on all + // splice statuses, but it would force us to handle every corner case where our peer doesn't behave correctly + // whereas they will all simply lead to a force-close. + d.commitments.receiveCommit(commit, channelKeys) match { + case Right((commitments1, revocation)) => + log.debug("received a new sig, spec:\n{}", commitments1.latest.specs2String) + if (commitments1.changes.localHasChanges) { + // if we have newly acknowledged changes let's sign them + self ! CMD_SIGN() } - case _ if d.commitments.ignoreRetransmittedCommitSig(commit) => - // We haven't received our peer's tx_signatures for the latest funding transaction and asked them to resend it on reconnection. - // They also resend their corresponding commit_sig, but we have already received it so we should ignore it. - // Note that the funding transaction may have confirmed while we were reconnecting. - log.info("ignoring commit_sig, we're still waiting for tx_signatures") - stay() - case _ => - // NB: in all other cases we process the commit_sig normally. We could do a full pattern matching on all - // splice statuses, but it would force us to handle corner cases like race condition between splice_init - // and a non-splice commit_sig - d.commitments.receiveCommit(sigs, channelKeys) match { - case Right((commitments1, revocation)) => - log.debug("received a new sig, spec:\n{}", commitments1.latest.specs2String) - if (commitments1.changes.localHasChanges) { - // if we have newly acknowledged changes let's sign them - self ! CMD_SIGN() - } - if (d.commitments.availableBalanceForSend != commitments1.availableBalanceForSend) { - // we send this event only when our balance changes - context.system.eventStream.publish(AvailableBalanceChanged(self, d.channelId, d.aliases, commitments1, d.lastAnnouncement_opt)) - } - context.system.eventStream.publish(ChannelSignatureReceived(self, commitments1)) - // If we're now quiescent, we may send our stfu message. - val (d1, toSend) = d.spliceStatus match { - case SpliceStatus.NegotiatingQuiescence(cmd_opt, QuiescenceNegotiation.Initiator.QuiescenceRequested) if commitments1.localIsQuiescent => - val stfu = Stfu(d.channelId, initiator = true) - val spliceStatus1 = SpliceStatus.NegotiatingQuiescence(cmd_opt, QuiescenceNegotiation.Initiator.SentStfu(stfu)) - (d.copy(commitments = commitments1, spliceStatus = spliceStatus1), Seq(revocation, stfu)) - case SpliceStatus.NegotiatingQuiescence(_, _: QuiescenceNegotiation.NonInitiator.ReceivedStfu) if commitments1.localIsQuiescent => - val stfu = Stfu(d.channelId, initiator = false) - (d.copy(commitments = commitments1, spliceStatus = SpliceStatus.NonInitiatorQuiescent), Seq(revocation, stfu)) - case _ => - (d.copy(commitments = commitments1), Seq(revocation)) - } - stay() using d1 storing() sending toSend - case Left(cause) => handleLocalError(cause, d, Some(commit)) + if (d.commitments.availableBalanceForSend != commitments1.availableBalanceForSend) { + // we send this event only when our balance changes + context.system.eventStream.publish(AvailableBalanceChanged(self, d.channelId, d.aliases, commitments1, d.lastAnnouncement_opt)) + } + context.system.eventStream.publish(ChannelSignatureReceived(self, commitments1)) + // If we're now quiescent, we may send our stfu message. + val (d1, toSend) = d.spliceStatus match { + case SpliceStatus.NegotiatingQuiescence(cmd_opt, QuiescenceNegotiation.Initiator.QuiescenceRequested) if commitments1.localIsQuiescent => + val stfu = Stfu(d.channelId, initiator = true) + val spliceStatus1 = SpliceStatus.NegotiatingQuiescence(cmd_opt, QuiescenceNegotiation.Initiator.SentStfu(stfu)) + (d.copy(commitments = commitments1, spliceStatus = spliceStatus1), Seq(revocation, stfu)) + case SpliceStatus.NegotiatingQuiescence(_, _: QuiescenceNegotiation.NonInitiator.ReceivedStfu) if commitments1.localIsQuiescent => + val stfu = Stfu(d.channelId, initiator = false) + (d.copy(commitments = commitments1, spliceStatus = SpliceStatus.NonInitiatorQuiescent), Seq(revocation, stfu)) + case _ => + (d.copy(commitments = commitments1), Seq(revocation)) } + stay() using d1 storing() sending toSend + case Left(cause) => handleLocalError(cause, d, Some(commit)) } - case None => stay() } case Event(revocation: RevokeAndAck, d: DATA_NORMAL) => @@ -1574,36 +1570,32 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall stay() } - case Event(commit: CommitSig, d@DATA_SHUTDOWN(_, localShutdown, remoteShutdown, closeStatus)) => - aggregateSigs(commit) match { - case Some(sigs) => - d.commitments.receiveCommit(sigs, channelKeys) match { - case Right((commitments1, revocation)) => - // we always reply with a revocation - log.debug("received a new sig:\n{}", commitments1.latest.specs2String) - context.system.eventStream.publish(ChannelSignatureReceived(self, commitments1)) - if (commitments1.hasNoPendingHtlcsOrFeeUpdate) { - if (Features.canUseFeature(d.commitments.params.localParams.initFeatures, d.commitments.params.remoteParams.initFeatures, Features.SimpleClose)) { - val (d1, closingComplete_opt) = startSimpleClose(d.commitments, localShutdown, remoteShutdown, closeStatus) - goto(NEGOTIATING_SIMPLE) using d1 storing() sending revocation +: closingComplete_opt.toSeq - } else if (d.commitments.params.localParams.paysClosingFees) { - // we pay the closing fees, so we initiate the negotiation by sending the first closing_signed - val (closingTx, closingSigned) = MutualClose.makeFirstClosingTx(channelKeys, commitments1.latest, localShutdown.scriptPubKey, remoteShutdown.scriptPubKey, nodeParams.currentFeeratesForFundingClosing, nodeParams.onChainFeeConf, d.closeStatus.feerates_opt) - goto(NEGOTIATING) using DATA_NEGOTIATING(commitments1, localShutdown, remoteShutdown, List(List(ClosingTxProposed(closingTx, closingSigned))), bestUnpublishedClosingTx_opt = None) storing() sending revocation :: closingSigned :: Nil - } else { - // we are not the channel initiator, will wait for their closing_signed - goto(NEGOTIATING) using DATA_NEGOTIATING(commitments1, localShutdown, remoteShutdown, closingTxProposed = List(List()), bestUnpublishedClosingTx_opt = None) storing() sending revocation - } - } else { - if (commitments1.changes.localHasChanges) { - // if we have newly acknowledged changes let's sign them - self ! CMD_SIGN() - } - stay() using d.copy(commitments = commitments1) storing() sending revocation - } - case Left(cause) => handleLocalError(cause, d, Some(commit)) + case Event(commit: CommitSigs, d@DATA_SHUTDOWN(_, localShutdown, remoteShutdown, closeStatus)) => + d.commitments.receiveCommit(commit, channelKeys) match { + case Right((commitments1, revocation)) => + // we always reply with a revocation + log.debug("received a new sig:\n{}", commitments1.latest.specs2String) + context.system.eventStream.publish(ChannelSignatureReceived(self, commitments1)) + if (commitments1.hasNoPendingHtlcsOrFeeUpdate) { + if (Features.canUseFeature(d.commitments.params.localParams.initFeatures, d.commitments.params.remoteParams.initFeatures, Features.SimpleClose)) { + val (d1, closingComplete_opt) = startSimpleClose(d.commitments, localShutdown, remoteShutdown, closeStatus) + goto(NEGOTIATING_SIMPLE) using d1 storing() sending revocation +: closingComplete_opt.toSeq + } else if (d.commitments.params.localParams.paysClosingFees) { + // we pay the closing fees, so we initiate the negotiation by sending the first closing_signed + val (closingTx, closingSigned) = MutualClose.makeFirstClosingTx(channelKeys, commitments1.latest, localShutdown.scriptPubKey, remoteShutdown.scriptPubKey, nodeParams.currentFeeratesForFundingClosing, nodeParams.onChainFeeConf, d.closeStatus.feerates_opt) + goto(NEGOTIATING) using DATA_NEGOTIATING(commitments1, localShutdown, remoteShutdown, List(List(ClosingTxProposed(closingTx, closingSigned))), bestUnpublishedClosingTx_opt = None) storing() sending revocation :: closingSigned :: Nil + } else { + // we are not the channel initiator, will wait for their closing_signed + goto(NEGOTIATING) using DATA_NEGOTIATING(commitments1, localShutdown, remoteShutdown, closingTxProposed = List(List()), bestUnpublishedClosingTx_opt = None) storing() sending revocation + } + } else { + if (commitments1.changes.localHasChanges) { + // if we have newly acknowledged changes let's sign them + self ! CMD_SIGN() + } + stay() using d.copy(commitments = commitments1) storing() sending revocation } - case None => stay() + case Left(cause) => handleLocalError(cause, d, Some(commit)) } case Event(revocation: RevokeAndAck, d@DATA_SHUTDOWN(_, localShutdown, remoteShutdown, closeStatus)) => @@ -3020,7 +3012,6 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall /** On disconnection we clear up stashes. */ onTransition { case _ -> OFFLINE => - sigStash = Nil announcementSigsStash = Map.empty announcementSigsSent = Set.empty spliceLockedSent = Map.empty[TxId, Long] @@ -3037,19 +3028,6 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall 888 888 d88P 888 888 Y888 8888888P" 88888888 8888888888 888 T88b "Y8888P" */ - /** For splices we will send one commit_sig per active commitments. */ - private def aggregateSigs(commit: CommitSig): Option[Seq[CommitSig]] = { - sigStash = sigStash :+ commit - log.debug("received sig for batch of size={}", commit.batchSize) - if (sigStash.size == commit.batchSize) { - val sigs = sigStash - sigStash = Nil - Some(sigs) - } else { - None - } - } - private def handleCurrentFeerate(c: CurrentFeerates, d: ChannelDataWithCommitments) = { val commitments = d.commitments.latest val networkFeeratePerKw = nodeParams.onChainFeeConf.getCommitmentFeerate(nodeParams.currentBitcoinCoreFeerates, remoteNodeId, d.commitments.params.commitmentFormat, commitments.capacity) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala index 32c8e965dd..c05106d51e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/TransportHandler.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.crypto.ChaCha20Poly1305.ChaCha20Poly1305Error import fr.acinq.eclair.crypto.Noise._ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes -import fr.acinq.eclair.wire.protocol.{AnnouncementSignatures, RoutingMessage} +import fr.acinq.eclair.wire.protocol.{AnnouncementSignatures, LightningMessage, RoutingMessage} import fr.acinq.eclair.{Diagnostics, FSMDiagnosticActorLogging, Logs, getSimpleClassName} import scodec.bits.ByteVector import scodec.{Attempt, Codec, DecodeResult} @@ -40,7 +40,7 @@ import scala.reflect.ClassTag /** * see BOLT #8 * This class handles the transport layer: - * - initial handshake. upon completion we will have a pair of cipher states (one for encryption, one for decryption) + * - initial handshake. upon completion we will have a pair of cipher states (one for encryption, one for decryption) * - encryption/decryption of messages * * Once the initial handshake has been completed successfully, the handler will create a listener actor with the @@ -50,23 +50,22 @@ import scala.reflect.ClassTag * @param rs remote node static public key (which must be known before we initiate communication) * @param connection actor that represents the other node's */ -class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], connection: ActorRef, codec: Codec[T]) extends Actor with FSMDiagnosticActorLogging[TransportHandler.State, TransportHandler.Data] { +class TransportHandler(keyPair: KeyPair, rs: Option[ByteVector], connection: ActorRef, codec: Codec[LightningMessage]) extends Actor with FSMDiagnosticActorLogging[TransportHandler.State, TransportHandler.Data] { // will hold the peer's public key once it is available (we don't know it right away in case of an incoming connection) var remoteNodeId_opt: Option[PublicKey] = rs.map(PublicKey(_)) - val wireLog = new BusLogging(context.system.eventStream, "", classOf[Diagnostics], context.system.asInstanceOf[ExtendedActorSystem].logFilter) with DiagnosticLoggingAdapter + private val wireLog = new BusLogging(context.system.eventStream, "", classOf[Diagnostics], context.system.asInstanceOf[ExtendedActorSystem].logFilter) with DiagnosticLoggingAdapter - def diag(message: T, direction: String): Unit = { - require(direction == "IN" || direction == "OUT") + private def logMessage(message: LightningMessage, direction: String): Unit = { val channelId_opt = Logs.channelId(message) wireLog.mdc(Logs.mdc(LogCategory(message), remoteNodeId_opt, channelId_opt)) if (channelId_opt.isDefined) { // channel-related messages are logged as info - wireLog.info(s"$direction msg={}", message) + wireLog.info("{} msg={}", direction, message) } else { // other messages (e.g. routing gossip) are logged as debug - wireLog.debug(s"$direction msg={}", message) + wireLog.debug("{} msg={}", direction, message) } wireLog.clearMDC() } @@ -79,11 +78,11 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co def buf(message: ByteVector): ByteString = ByteString.fromArray(message.toArray) // it means we initiate the dialog - val isWriter = rs.isDefined + private val isWriter = rs.isDefined context.watch(connection) - val reader = if (isWriter) { + private val reader = if (isWriter) { val state = makeWriter(keyPair, rs.get) val (state1, message, None) = state.write(ByteVector.empty) log.debug(s"sending prefix + $message") @@ -93,12 +92,12 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co makeReader(keyPair) } - def decodeAndSendToListener(listener: ActorRef, plaintextMessages: Seq[ByteVector]): Map[T, Int] = { + private def decodeAndSendToListener(listener: ActorRef, plaintextMessages: Seq[ByteVector]): Map[LightningMessage, Int] = { log.debug("decoding {} plaintext messages", plaintextMessages.size) - var m: Map[T, Int] = Map() + var m = Map.empty[LightningMessage, Int] plaintextMessages.foreach(plaintext => codec.decode(plaintext.bits) match { case Attempt.Successful(DecodeResult(message, _)) => - diag(message, "IN") + logMessage(message, "IN") Monitoring.Metrics.MessageSize.withTag(Monitoring.Tags.MessageDirection, Monitoring.Tags.MessageDirections.IN).record(plaintext.size) listener ! message m += (message -> (m.getOrElse(message, 0) + 1)) @@ -132,25 +131,22 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co val nextStateData = WaitingForListenerData(Encryptor(ExtendedCipherState(enc, ck)), Decryptor(ExtendedCipherState(dec, ck), ciphertextLength = None, remainder)) goto(WaitingForListener) using nextStateData - case (writer, _, None) => { + case (writer, _, None) => writer.write(ByteVector.empty) match { - case (reader1, message, None) => { + case (reader1, message, None) => // we're still in the middle of the handshake process and the other end must first received our next // message before they can reply if (remainder.nonEmpty) throw UnexpectedDataDuringHandshake(ByteVector(remainder)) connection ! Tcp.Write(buf(TransportHandler.prefix +: message)) stay() using HandshakeData(reader1, remainder) - } - case (_, message, Some((enc, dec, ck))) => { + case (_, message, Some((enc, dec, ck))) => connection ! Tcp.Write(buf(TransportHandler.prefix +: message)) val remoteNodeId = PublicKey(writer.rs) remoteNodeId_opt = Some(remoteNodeId) context.parent ! HandshakeCompleted(remoteNodeId) val nextStateData = WaitingForListenerData(Encryptor(ExtendedCipherState(enc, ck)), Decryptor(ExtendedCipherState(dec, ck), ciphertextLength = None, remainder)) goto(WaitingForListener) using nextStateData - } } - } } } } @@ -169,13 +165,13 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co log.debug("no decoded messages, resuming reading") connection ! Tcp.ResumeReading } - goto(Normal) using NormalData(d.encryptor, dec1, listener, sendBuffer = SendBuffer(Queue.empty[T], Queue.empty[T]), unackedReceived = unackedReceived1, unackedSent = None) + goto(Normal) using NormalData(d.encryptor, dec1, listener, sendBuffer = SendBuffer(Queue.empty[LightningMessage], Queue.empty[LightningMessage]), unackedReceived = unackedReceived1, unackedSent = None) } } when(Normal) { handleExceptions { - case Event(Tcp.Received(data), d: NormalData[T @unchecked]) => + case Event(Tcp.Received(data), d: NormalData) => log.debug("received chunk of size={}", data.size) val (dec1, plaintextMessages) = d.decryptor.copy(buffer = d.decryptor.buffer ++ data).decrypt() val unackedReceived1 = decodeAndSendToListener(d.listener, plaintextMessages) @@ -185,7 +181,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co } stay() using d.copy(decryptor = dec1, unackedReceived = unackedReceived1) - case Event(ReadAck(msg: T), d: NormalData[T @unchecked]) => + case Event(ReadAck(msg: LightningMessage), d: NormalData) => // how many occurrences of this message are still unacked? val remaining = d.unackedReceived.getOrElse(msg, 0) - 1 log.debug("acking message {}", msg) @@ -199,7 +195,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co } stay() using d.copy(unackedReceived = unackedReceived1) - case Event(t: T, d: NormalData[T @unchecked]) => + case Event(t: LightningMessage, d: NormalData) => if (d.sendBuffer.normalPriority.size + d.sendBuffer.lowPriority.size >= MAX_BUFFERED) { log.warning("send buffer overrun, closing connection") connection ! PoisonPill @@ -213,7 +209,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co } stay() using d.copy(sendBuffer = sendBuffer1) } else { - diag(t, "OUT") + logMessage(t, "OUT") val blob = codec.encode(t).require.toByteVector Monitoring.Metrics.MessageSize.withTag(Monitoring.Tags.MessageDirection, Monitoring.Tags.MessageDirections.OUT).record(blob.size) val (enc1, ciphertext) = d.encryptor.encrypt(blob) @@ -221,9 +217,9 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co stay() using d.copy(encryptor = enc1, unackedSent = Some(t)) } - case Event(WriteAck, d: NormalData[T @unchecked]) => - def send(t: T) = { - diag(t, "OUT") + case Event(WriteAck, d: NormalData) => + def send(t: LightningMessage) = { + logMessage(t, "OUT") val blob = codec.encode(t).require.toByteVector Monitoring.Metrics.MessageSize.withTag(Monitoring.Tags.MessageDirection, Monitoring.Tags.MessageDirections.OUT).record(blob.size) val (enc1, ciphertext) = d.encryptor.encrypt(blob) @@ -260,7 +256,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co case Event(msg, d) => d match { - case n: NormalData[_] => log.warning(s"unhandled message $msg in state normal unackedSent=${n.unackedSent.size} unackedReceived=${n.unackedReceived.size} sendBuffer.lowPriority=${n.sendBuffer.lowPriority.size} sendBuffer.normalPriority=${n.sendBuffer.normalPriority.size}") + case n: NormalData => log.warning(s"unhandled message $msg in state normal unackedSent=${n.unackedSent.size} unackedReceived=${n.unackedReceived.size} sendBuffer.lowPriority=${n.sendBuffer.lowPriority.size} sendBuffer.normalPriority=${n.sendBuffer.normalPriority.size}") case _ => log.warning(s"unhandled message $msg in state ${d.getClass.getSimpleName}") } stay() @@ -273,7 +269,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co Logs.withMdc(diagLog)(Logs.mdc(category_opt = Some(Logs.LogCategory.CONNECTION), remoteNodeId_opt = remoteNodeId_opt)) { connection ! Tcp.Close // attempts to gracefully close the connection when dying stateData match { - case normal: NormalData[_] => + case normal: NormalData => // NB: we deduplicate on the class name: each class will appear once but there may be many instances (less verbose and gives debug hints) log.info("stopping (unackedReceived={} unackedSent={})", normal.unackedReceived.keys.map(getSimpleClassName).toSet.mkString(","), normal.unackedSent.map(getSimpleClassName)) case _ => @@ -297,8 +293,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co case t: Throwable => t match { // for well known crypto error, we don't display the stack trace - case _: InvalidTransportPrefix => log.error(s"crypto error: ${t.getMessage}") - case _: ChaCha20Poly1305Error => log.error(s"crypto error: ${t.getMessage}") + case _: InvalidTransportPrefix | _: ChaCha20Poly1305Error => log.error(s"crypto error: ${t.getMessage}") case _ => log.error(t, "") } throw t @@ -309,19 +304,19 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co object TransportHandler { - def props[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], connection: ActorRef, codec: Codec[T]): Props = Props(new TransportHandler(keyPair, rs, connection, codec)) + def props[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], connection: ActorRef, codec: Codec[LightningMessage]): Props = Props(new TransportHandler(keyPair, rs, connection, codec)) - val MAX_BUFFERED = 1000000L + private val MAX_BUFFERED = 1000000L // see BOLT #8 // this prefix is prepended to all Noise messages sent during the handshake phase val prefix: Byte = 0x00 + private val prologue = ByteVector.view("lightning".getBytes("UTF-8")) - case class InvalidTransportPrefix(buffer: ByteVector) extends RuntimeException(s"invalid transport prefix first64=${buffer.take(64).toHex}") - - case class UnexpectedDataDuringHandshake(buffer: ByteVector) extends RuntimeException(s"unexpected additional data received during handshake first64=${buffer.take(64).toHex}") - - val prologue = ByteVector.view("lightning".getBytes("UTF-8")) + // @formatter:off + private case class InvalidTransportPrefix(buffer: ByteVector) extends RuntimeException(s"invalid transport prefix first64=${buffer.take(64).toHex}") + private case class UnexpectedDataDuringHandshake(buffer: ByteVector) extends RuntimeException(s"unexpected additional data received during handshake first64=${buffer.take(64).toHex}") + // @formatter:on /** * See BOLT #8: during the handshake phase we are expecting 3 messages of 50, 50 and 66 bytes (including the prefix) @@ -329,17 +324,17 @@ object TransportHandler { * @param reader handshake state reader * @return the size of the message the reader is expecting */ - def expectedLength(reader: Noise.HandshakeStateReader) = reader.messages.length match { + private def expectedLength(reader: Noise.HandshakeStateReader): Int = reader.messages.length match { case 3 | 2 => 50 case 1 => 66 } - def makeWriter(localStatic: KeyPair, remoteStatic: ByteVector) = Noise.HandshakeState.initializeWriter( + private def makeWriter(localStatic: KeyPair, remoteStatic: ByteVector): HandshakeStateWriter = Noise.HandshakeState.initializeWriter( Noise.handshakePatternXK, prologue, localStatic, KeyPair(ByteVector.empty, ByteVector.empty), remoteStatic, ByteVector.empty, Noise.Secp256k1DHFunctions, Noise.Chacha20Poly1305CipherFunctions, Noise.SHA256HashFunctions) - def makeReader(localStatic: KeyPair) = Noise.HandshakeState.initializeReader( + private def makeReader(localStatic: KeyPair): HandshakeStateReader = Noise.HandshakeState.initializeReader( Noise.handshakePatternXK, prologue, localStatic, KeyPair(ByteVector.empty, ByteVector.empty), ByteVector.empty, ByteVector.empty, Noise.Secp256k1DHFunctions, Noise.Chacha20Poly1305CipherFunctions, Noise.SHA256HashFunctions) @@ -351,37 +346,32 @@ object TransportHandler { * @param ck chaining key */ case class ExtendedCipherState(cs: CipherState, ck: ByteVector) extends CipherState { - override def cipher: CipherFunctions = cs.cipher - - override def hasKey: Boolean = cs.hasKey + override val cipher: CipherFunctions = cs.cipher + override val hasKey: Boolean = cs.hasKey override def encryptWithAd(ad: ByteVector, plaintext: ByteVector): (CipherState, ByteVector) = { cs match { case UninitializedCipherState(_) => (this, plaintext) - case InitializedCipherState(k, n, _) if n == 999 => { + case InitializedCipherState(k, n, _) if n == 999 => val (_, ciphertext) = cs.encryptWithAd(ad, plaintext) val (ck1, k1) = SHA256HashFunctions.hkdf(ck, k) (this.copy(cs = cs.initializeKey(k1), ck = ck1), ciphertext) - } - case InitializedCipherState(_, n, _) => { + case _: InitializedCipherState => val (cs1, ciphertext) = cs.encryptWithAd(ad, plaintext) (this.copy(cs = cs1), ciphertext) - } } } override def decryptWithAd(ad: ByteVector, ciphertext: ByteVector): (CipherState, ByteVector) = { cs match { case UninitializedCipherState(_) => (this, ciphertext) - case InitializedCipherState(k, n, _) if n == 999 => { + case InitializedCipherState(k, n, _) if n == 999 => val (_, plaintext) = cs.decryptWithAd(ad, ciphertext) val (ck1, k1) = SHA256HashFunctions.hkdf(ck, k) (this.copy(cs = cs.initializeKey(k1), ck = ck1), plaintext) - } - case InitializedCipherState(_, n, _) => { + case _: InitializedCipherState => val (cs1, plaintext) = cs.decryptWithAd(ad, ciphertext) (this.copy(cs = cs1), plaintext) - } } } } @@ -438,23 +428,23 @@ object TransportHandler { // @formatter:off sealed trait State - case object Handshake extends State + private case object Handshake extends State case object WaitingForListener extends State case object Normal extends State sealed trait Data - case class HandshakeData(reader: Noise.HandshakeStateReader, buffer: ByteString = ByteString.empty) extends Data - case class WaitingForListenerData(encryptor: Encryptor, decryptor: Decryptor) extends Data - case class NormalData[T](encryptor: Encryptor, decryptor: Decryptor, listener: ActorRef, sendBuffer: SendBuffer[T], unackedReceived: Map[T, Int], unackedSent: Option[T]) extends Data + private case class HandshakeData(reader: Noise.HandshakeStateReader, buffer: ByteString = ByteString.empty) extends Data + private case class WaitingForListenerData(encryptor: Encryptor, decryptor: Decryptor) extends Data + private case class NormalData(encryptor: Encryptor, decryptor: Decryptor, listener: ActorRef, sendBuffer: SendBuffer, unackedReceived: Map[LightningMessage, Int], unackedSent: Option[LightningMessage]) extends Data - case class SendBuffer[T](normalPriority: Queue[T], lowPriority: Queue[T]) + private case class SendBuffer(normalPriority: Queue[LightningMessage], lowPriority: Queue[LightningMessage]) case class Listener(listener: ActorRef) case class HandshakeCompleted(remoteNodeId: PublicKey) case class ReadAck(msg: Any) extends RemoteTypes - case object WriteAck extends Tcp.Event + private case object WriteAck extends Tcp.Event // @formatter:on } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala index ffcf92cac7..8f2fe0254c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala @@ -18,8 +18,8 @@ package fr.acinq.eclair.io import akka.actor.{ActorRef, FSM, OneForOneStrategy, PoisonPill, Props, Stash, SupervisorStrategy, Terminated} import akka.event.Logging.MDC -import fr.acinq.bitcoin.scalacompat.{BlockHash, ByteVector32} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.bitcoin.scalacompat.{BlockHash, ByteVector32} import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair.crypto.Noise.KeyPair import fr.acinq.eclair.crypto.TransportHandler @@ -28,7 +28,7 @@ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{FSMDiagnosticActorLogging, FeatureCompatibilityResult, Features, InitFeature, Logs, TimestampMilli, TimestampSecond} +import fr.acinq.eclair.{FSMDiagnosticActorLogging, Features, InitFeature, Logs, TimestampMilli, TimestampSecond} import scodec.Attempt import scodec.bits.ByteVector @@ -206,11 +206,13 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A stay() case Event(msg: LightningMessage, d: ConnectedData) if sender() != d.transport => // if the message doesn't originate from the transport, it is an outgoing message - d.transport forward msg + msg match { + case batch: CommitSigBatch => batch.messages.foreach(msg => d.transport forward msg) + case msg => d.transport forward msg + } msg match { // If we send any channel management message to this peer, the connection should be persistent. - case _: ChannelMessage if !d.isPersistent => - stay() using d.copy(isPersistent = true) + case _: ChannelMessage if !d.isPersistent => stay() using d.copy(isPersistent = true) case _ => stay() } @@ -341,10 +343,48 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A stay() case Event(msg: LightningMessage, d: ConnectedData) => - // we acknowledge and pass all other messages to the peer + // We immediately acknowledge all other messages. d.transport ! TransportHandler.ReadAck(msg) - d.peer ! msg - stay() + // We immediately forward messages to the peer, unless they are part of a batch, in which case we wait to + // receive the whole batch before forwarding. + msg match { + case msg: CommitSig => + msg.tlvStream.get[CommitSigTlv.BatchTlv].map(_.size) match { + case Some(batchSize) if batchSize > 25 => + log.warning("received legacy batch of commit_sig exceeding our threshold ({} > 25), processing messages individually", batchSize) + // We don't want peers to be able to exhaust our memory by sending batches of dummy messages that we keep in RAM. + d.peer ! msg + stay() + case Some(batchSize) if batchSize > 1 => + d.legacyCommitSigBatch_opt match { + case Some(pending) if pending.channelId != msg.channelId || pending.batchSize != batchSize => + log.warning("received invalid commit_sig batch while a different batch isn't complete") + // This should never happen, otherwise it will likely lead to a force-close. + d.peer ! CommitSigBatch(pending.received) + stay() using d.copy(legacyCommitSigBatch_opt = Some(PendingCommitSigBatch(msg.channelId, batchSize, Seq(msg)))) + case Some(pending) => + val received1 = pending.received :+ msg + if (received1.size == batchSize) { + log.debug("received last commit_sig in legacy batch for channel_id={}", msg.channelId) + d.peer ! CommitSigBatch(received1) + stay() using d.copy(legacyCommitSigBatch_opt = None) + } else { + log.debug("received commit_sig {}/{} in legacy batch for channel_id={}", received1.size, batchSize, msg.channelId) + stay() using d.copy(legacyCommitSigBatch_opt = Some(pending.copy(received = received1))) + } + case None => + log.debug("received first commit_sig in legacy batch of size {} for channel_id={}", batchSize, msg.channelId) + stay() using d.copy(legacyCommitSigBatch_opt = Some(PendingCommitSigBatch(msg.channelId, batchSize, Seq(msg)))) + } + case _ => + log.debug("received individual commit_sig for channel_id={}", msg.channelId) + d.peer ! msg + stay() + } + case _ => + d.peer ! msg + stay() + } case Event(readAck: TransportHandler.ReadAck, d: ConnectedData) => // we just forward acks to the transport (e.g. from the router) @@ -564,8 +604,19 @@ object PeerConnection { case class AuthenticatingData(pendingAuth: PendingAuth, transport: ActorRef, isPersistent: Boolean) extends Data with HasTransport case class BeforeInitData(remoteNodeId: PublicKey, pendingAuth: PendingAuth, transport: ActorRef, isPersistent: Boolean) extends Data with HasTransport case class InitializingData(chainHash: BlockHash, pendingAuth: PendingAuth, remoteNodeId: PublicKey, transport: ActorRef, peer: ActorRef, localInit: protocol.Init, doSync: Boolean, isPersistent: Boolean) extends Data with HasTransport - case class ConnectedData(chainHash: BlockHash, remoteNodeId: PublicKey, transport: ActorRef, peer: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, rebroadcastDelay: FiniteDuration, gossipTimestampFilter: Option[GossipTimestampFilter] = None, behavior: Behavior = Behavior(), expectedPong_opt: Option[ExpectedPong] = None, isPersistent: Boolean) extends Data with HasTransport - + case class ConnectedData(chainHash: BlockHash, + remoteNodeId: PublicKey, + transport: ActorRef, + peer: ActorRef, + localInit: protocol.Init, remoteInit: protocol.Init, + rebroadcastDelay: FiniteDuration, + gossipTimestampFilter: Option[GossipTimestampFilter] = None, + behavior: Behavior = Behavior(), + expectedPong_opt: Option[ExpectedPong] = None, + legacyCommitSigBatch_opt: Option[PendingCommitSigBatch] = None, + isPersistent: Boolean) extends Data with HasTransport + + case class PendingCommitSigBatch(channelId: ByteVector32, batchSize: Int, received: Seq[CommitSig]) case class ExpectedPong(ping: Ping, timestamp: TimestampMilli = TimestampMilli.now()) case class PingTimeout(ping: Ping) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala index 11034b42d0..2c6a95b43b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala @@ -427,11 +427,25 @@ case class UpdateFailMalformedHtlc(channelId: ByteVector32, failureCode: Int, tlvStream: TlvStream[UpdateFailMalformedHtlcTlv] = TlvStream.empty) extends HtlcMessage with UpdateMessage with HasChannelId with HtlcFailureMessage +/** + * [[CommitSig]] can either be sent individually or as part of a batch. When sent in a batch (which happens when there + * are pending splice transactions), we treat the whole batch as a single lightning message and group them on the wire. + */ +sealed trait CommitSigs extends HtlcMessage with HasChannelId + +object CommitSigs { + def apply(sigs: Seq[CommitSig]): CommitSigs = if (sigs.size == 1) sigs.head else CommitSigBatch(sigs) +} + case class CommitSig(channelId: ByteVector32, signature: ByteVector64, htlcSignatures: List[ByteVector64], - tlvStream: TlvStream[CommitSigTlv] = TlvStream.empty) extends HtlcMessage with HasChannelId { - val batchSize: Int = tlvStream.get[CommitSigTlv.BatchTlv].map(_.size).getOrElse(1) + tlvStream: TlvStream[CommitSigTlv] = TlvStream.empty) extends CommitSigs + +case class CommitSigBatch(messages: Seq[CommitSig]) extends CommitSigs { + require(messages.map(_.channelId).toSet.size == 1, "commit_sig messages in a batch must be for the same channel") + val channelId: ByteVector32 = messages.head.channelId + val batchSize: Int = messages.size } case class RevokeAndAck(channelId: ByteVector32, diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala index 48ebba08a2..3857770251 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher._ import fr.acinq.eclair.blockchain.fee.{FeeratePerKw, FeeratesPerKw} -import fr.acinq.eclair.blockchain.{DummyOnChainWallet, OnChainWallet, OnChainPubkeyCache, SingleKeyOnChainWallet} +import fr.acinq.eclair.blockchain.{DummyOnChainWallet, OnChainPubkeyCache, OnChainWallet, SingleKeyOnChainWallet} import fr.acinq.eclair.channel._ import fr.acinq.eclair.channel.fsm.Channel import fr.acinq.eclair.channel.publish.TxPublisher @@ -463,31 +463,17 @@ trait ChannelStateTestsBase extends Assertions with Eventually { val rHasChanges = r.stateData.asInstanceOf[ChannelDataWithCommitments].commitments.changes.localHasChanges s ! CMD_SIGN(Some(sender.ref)) sender.expectMsgType[RES_SUCCESS[CMD_SIGN]] - var sigs2r = 0 - var batchSize = 0 - do { - val sig = s2r.expectMsgType[CommitSig] - s2r.forward(r) - sigs2r += 1 - batchSize = sig.batchSize - } while (sigs2r < batchSize) + s2r.expectMsgType[CommitSigs] + s2r.forward(r) r2s.expectMsgType[RevokeAndAck] r2s.forward(s) - var sigr2s = 0 - do { - r2s.expectMsgType[CommitSig] - r2s.forward(s) - sigr2s += 1 - } while (sigr2s < batchSize) + r2s.expectMsgType[CommitSigs] + r2s.forward(s) s2r.expectMsgType[RevokeAndAck] s2r.forward(r) if (rHasChanges) { - sigs2r = 0 - do { - s2r.expectMsgType[CommitSig] - s2r.forward(r) - sigs2r += 1 - } while (sigs2r < batchSize) + s2r.expectMsgType[CommitSigs] + s2r.forward(r) r2s.expectMsgType[RevokeAndAck] r2s.forward(s) eventually { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala index c51e43176b..f77d84c5d6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala @@ -1517,20 +1517,14 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.expectMsgType[UpdateAddHtlc] alice2bob.forward(bob) alice ! CMD_SIGN() - val sigA1 = alice2bob.expectMsgType[CommitSig] - assert(sigA1.batchSize == 2) - alice2bob.forward(bob) - val sigA2 = alice2bob.expectMsgType[CommitSig] - assert(sigA2.batchSize == 2) - alice2bob.forward(bob) + val sigsA = alice2bob.expectMsgType[CommitSigBatch] + assert(sigsA.batchSize == 2) + alice2bob.forward(bob, sigsA) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) - val sigB1 = bob2alice.expectMsgType[CommitSig] - assert(sigB1.batchSize == 2) - bob2alice.forward(alice) - val sigB2 = bob2alice.expectMsgType[CommitSig] - assert(sigB2.batchSize == 2) - bob2alice.forward(alice) + val sigsB = bob2alice.expectMsgType[CommitSigBatch] + assert(sigsB.batchSize == 2) + bob2alice.forward(alice, sigsB) alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1)) @@ -1546,23 +1540,20 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.expectMsgType[UpdateAddHtlc] alice2bob.forward(bob) alice ! CMD_SIGN() - assert(alice2bob.expectMsgType[CommitSig].batchSize == 2) - assert(alice2bob.expectMsgType[CommitSig].batchSize == 2) + assert(alice2bob.expectMsgType[CommitSigBatch].batchSize == 2) // Bob disconnects before receiving Alice's commit_sig. disconnect(f) reconnect(f) alice2bob.expectMsgType[UpdateAddHtlc] alice2bob.forward(bob) - assert(alice2bob.expectMsgType[CommitSig].batchSize == 2) - alice2bob.forward(bob) - assert(alice2bob.expectMsgType[CommitSig].batchSize == 2) - alice2bob.forward(bob) + val sigsA = alice2bob.expectMsgType[CommitSigBatch] + assert(sigsA.batchSize == 2) + alice2bob.forward(bob, sigsA) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) - assert(bob2alice.expectMsgType[CommitSig].batchSize == 2) - bob2alice.forward(alice) - assert(bob2alice.expectMsgType[CommitSig].batchSize == 2) - bob2alice.forward(alice) + val sigsB = bob2alice.expectMsgType[CommitSigBatch] + assert(sigsB.batchSize == 2) + bob2alice.forward(alice, sigsB) alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1)) @@ -1679,17 +1670,16 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik alice2bob.forward(bob, spliceLockedAlice) val (preimage, htlc) = addHtlc(20_000_000 msat, alice, bob, alice2bob, bob2alice) alice ! CMD_SIGN() - val commitSigsAlice = (1 to 3).map(_ => alice2bob.expectMsgType[CommitSig]) - alice2bob.forward(bob, commitSigsAlice(0)) + val commitSigsAlice = alice2bob.expectMsgType[CommitSigBatch] + assert(commitSigsAlice.batchSize == 3) bob ! WatchPublishedTriggered(spliceTx2) val spliceLockedBob = bob2alice.expectMsgType[SpliceLocked] assert(spliceLockedBob.fundingTxId == spliceTx2.txid) bob2alice.forward(alice, spliceLockedBob) - alice2bob.forward(bob, commitSigsAlice(1)) - alice2bob.forward(bob, commitSigsAlice(2)) + alice2bob.forward(bob, commitSigsAlice) bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) - assert(bob2alice.expectMsgType[CommitSig].batchSize == 1) + bob2alice.expectMsgType[CommitSig] bob2alice.forward(alice) alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) @@ -2858,7 +2848,7 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik // Alice adds another HTLC that isn't signed by Bob. val (_, htlcOut2) = addHtlc(15_000_000 msat, alice, bob, alice2bob, bob2alice) alice ! CMD_SIGN() - alice2bob.expectMsgType[CommitSig] // Bob ignores Alice's message + alice2bob.expectMsgType[CommitSigBatch] // Bob ignores Alice's message // The first splice transaction confirms. alice ! WatchFundingConfirmedTriggered(BlockHeight(400000), 42, fundingTx1) @@ -3341,15 +3331,15 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik // Bob is waiting to sign its outgoing HTLC before sending stfu. bob2alice.expectNoMessage(100 millis) bob ! CMD_SIGN() - (0 until 3).foreach { _ => - bob2alice.expectMsgType[CommitSig] - bob2alice.forward(alice) + inside(bob2alice.expectMsgType[CommitSigBatch]) { batch => + assert(batch.batchSize == 3) + bob2alice.forward(alice, batch) } alice2bob.expectMsgType[RevokeAndAck] alice2bob.forward(bob) - (0 until 3).foreach { _ => - alice2bob.expectMsgType[CommitSig] - alice2bob.forward(bob) + inside(alice2bob.expectMsgType[CommitSigBatch]) { batch => + assert(batch.batchSize == 3) + alice2bob.forward(bob, batch) } bob2alice.expectMsgType[RevokeAndAck] bob2alice.forward(alice) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala index 0f0b2a89e2..725233c9d6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/TransportHandlerSpec.scala @@ -22,14 +22,14 @@ import akka.testkit.{TestActorRef, TestFSMRef, TestProbe} import fr.acinq.eclair.TestKitBaseClass import fr.acinq.eclair.crypto.Noise.{Chacha20Poly1305CipherFunctions, CipherState} import fr.acinq.eclair.crypto.TransportHandler.{Encryptor, ExtendedCipherState, Listener} -import fr.acinq.eclair.wire.protocol.CommonCodecs +import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{lightningMessageCodec, pingCodec} +import fr.acinq.eclair.wire.protocol.{LightningMessage, Ping, Pong} import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuiteLike import scodec.Codec import scodec.bits._ import scodec.codecs._ -import java.nio.charset.Charset import scala.annotation.tailrec import scala.concurrent.duration._ @@ -38,19 +38,19 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be import TransportHandlerSpec._ object Initiator { - val s = Noise.Secp256k1DHFunctions.generateKeyPair(hex"1111111111111111111111111111111111111111111111111111111111111111") + val s: Noise.KeyPair = Noise.Secp256k1DHFunctions.generateKeyPair(hex"1111111111111111111111111111111111111111111111111111111111111111") } object Responder { - val s = Noise.Secp256k1DHFunctions.generateKeyPair(hex"2121212121212121212121212121212121212121212121212121212121212121") + val s: Noise.KeyPair = Noise.Secp256k1DHFunctions.generateKeyPair(hex"2121212121212121212121212121212121212121212121212121212121212121") } test("successful handshake") { val pipe = system.actorOf(Props[MyPipe]()) val probe1 = TestProbe() val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, CommonCodecs.varsizebinarydata)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata)) + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, lightningMessageCodec)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, lightningMessageCodec)) pipe ! (initiator, responder) awaitCond(initiator.stateName == TransportHandler.WaitingForListener) @@ -62,43 +62,11 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be awaitCond(initiator.stateName == TransportHandler.Normal) awaitCond(responder.stateName == TransportHandler.Normal) - initiator.tell(ByteVector("hello".getBytes), probe1.ref) - probe2.expectMsg(ByteVector("hello".getBytes)) + initiator.tell(Ping(1105, ByteVector("hello".getBytes)), probe1.ref) + probe2.expectMsg(Ping(1105, ByteVector("hello".getBytes))) - responder.tell(ByteVector("bonjour".getBytes), probe2.ref) - probe1.expectMsg(ByteVector("bonjour".getBytes)) - - probe1.watch(pipe) - initiator.stop() - responder.stop() - system.stop(pipe) - probe1.expectTerminated(pipe) - } - - test("successful handshake with custom serializer") { - case class MyMessage(payload: String) - val mycodec: Codec[MyMessage] = ("payload" | scodec.codecs.string32L(Charset.defaultCharset())).as[MyMessage] - val pipe = system.actorOf(Props[MyPipe]()) - val probe1 = TestProbe() - val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, mycodec)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, mycodec)) - pipe ! (initiator, responder) - - awaitCond(initiator.stateName == TransportHandler.WaitingForListener) - awaitCond(responder.stateName == TransportHandler.WaitingForListener) - - initiator ! Listener(probe1.ref) - responder ! Listener(probe2.ref) - - awaitCond(initiator.stateName == TransportHandler.Normal) - awaitCond(responder.stateName == TransportHandler.Normal) - - initiator.tell(MyMessage("hello"), probe1.ref) - probe2.expectMsg(MyMessage("hello")) - - responder.tell(MyMessage("bonjour"), probe2.ref) - probe1.expectMsg(MyMessage("bonjour")) + responder.tell(Pong(ByteVector("bonjour".getBytes)), probe2.ref) + probe1.expectMsg(Pong(ByteVector("bonjour".getBytes))) probe1.watch(pipe) initiator.stop() @@ -108,22 +76,14 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be } test("handle unknown messages") { - sealed trait Message - case object Msg1 extends Message - case object Msg2 extends Message - - val codec1: Codec[Message] = discriminated[Message].by(uint8) - .typecase(1, provide(Msg1)) - - val codec12: Codec[Message] = discriminated[Message].by(uint8) - .typecase(1, provide(Msg1)) - .typecase(2, provide(Msg2)) + val incompleteCodec: Codec[LightningMessage] = discriminated[LightningMessage].by(uint16) + .typecase(18, pingCodec) val pipe = system.actorOf(Props[MyPipePull]()) val probe1 = TestProbe() val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, codec1)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, codec12)) + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, incompleteCodec)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, lightningMessageCodec)) pipe ! (initiator, responder) awaitCond(initiator.stateName == TransportHandler.WaitingForListener) @@ -135,16 +95,18 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be awaitCond(initiator.stateName == TransportHandler.Normal) awaitCond(responder.stateName == TransportHandler.Normal) - responder ! Msg1 - probe1.expectMsg(Msg1) - probe1.reply(TransportHandler.ReadAck(Msg1)) + val msg1 = Ping(130, hex"deadbeef") + responder ! msg1 + probe1.expectMsg(msg1) + probe1.reply(TransportHandler.ReadAck(msg1)) - responder ! Msg2 + responder ! Pong(hex"deadbeef") probe1.expectNoMessage(2 seconds) // unknown message - responder ! Msg1 - probe1.expectMsg(Msg1) - probe1.reply(TransportHandler.ReadAck(Msg1)) + val msg2 = Ping(42, hex"beefdead") + responder ! msg2 + probe1.expectMsg(msg2) + probe1.reply(TransportHandler.ReadAck(msg2)) probe1.watch(pipe) initiator.stop() @@ -157,8 +119,8 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be val pipe = system.actorOf(Props[MyPipeSplitter]()) val probe1 = TestProbe() val probe2 = TestProbe() - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, CommonCodecs.varsizebinarydata)) - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata)) + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Responder.s.pub), pipe, lightningMessageCodec)) + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, lightningMessageCodec)) pipe ! (initiator, responder) awaitCond(initiator.stateName == TransportHandler.WaitingForListener) @@ -170,11 +132,11 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be awaitCond(initiator.stateName == TransportHandler.Normal) awaitCond(responder.stateName == TransportHandler.Normal) - initiator.tell(ByteVector("hello".getBytes), probe1.ref) - probe2.expectMsg(ByteVector("hello".getBytes)) + initiator.tell(Ping(187, ByteVector("hello".getBytes)), probe1.ref) + probe2.expectMsg(Ping(187, ByteVector("hello".getBytes))) - responder.tell(ByteVector("bonjour".getBytes), probe2.ref) - probe1.expectMsg(ByteVector("bonjour".getBytes)) + responder.tell(Pong(ByteVector("bonjour".getBytes)), probe2.ref) + probe1.expectMsg(Pong(ByteVector("bonjour".getBytes))) probe1.watch(pipe) initiator.stop() @@ -187,11 +149,11 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be val pipe = system.actorOf(Props[MyPipe]()) val probe1 = TestProbe() val supervisor = TestActorRef(Props(new MySupervisor())) - val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Initiator.s.pub), pipe, CommonCodecs.varsizebinarydata), supervisor, "ini") - val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, CommonCodecs.varsizebinarydata), supervisor, "res") + val initiator = TestFSMRef(new TransportHandler(Initiator.s, Some(Initiator.s.pub), pipe, lightningMessageCodec), supervisor, "ini") + val responder = TestFSMRef(new TransportHandler(Responder.s, None, pipe, lightningMessageCodec), supervisor, "res") probe1.watch(responder) pipe ! (initiator, responder) - + // We automatically disconnect after a while if the handshake doesn't succeed. probe1.expectTerminated(responder, 3 seconds) } @@ -219,9 +181,7 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be */ val ck = hex"919219dbb2920afa8db80f9a51787a840bcf111ed8d588caf9ab4be716e42b01" val sk = hex"969ab31b4d288cedf6218839b27a3e2140827047f2c0f01bf5c04435d43511a9" - val rk = hex"bb9020b8965f4df047e07f955f3c4b88418984aadc5cdb35096b9ea8fa5c3442" val enc = ExtendedCipherState(CipherState(sk, Chacha20Poly1305CipherFunctions), ck) - val dec = ExtendedCipherState(CipherState(rk, Chacha20Poly1305CipherFunctions), ck) @tailrec def loop(cs: Encryptor, count: Int, acc: Vector[ByteVector] = Vector.empty[ByteVector]): Vector[ByteVector] = { @@ -244,15 +204,13 @@ class TransportHandlerSpec extends TestKitBaseClass with AnyFunSuiteLike with Be object TransportHandlerSpec { class MyPipe extends Actor with Stash with ActorLogging { - - def receive = { + def receive: Receive = { case (a: ActorRef, b: ActorRef) => unstashAll() context watch a context watch b context become ready(a, b) - - case msg => stash() + case _ => stash() } def ready(a: ActorRef, b: ActorRef): Receive = { @@ -267,15 +225,13 @@ object TransportHandlerSpec { } class MyPipeSplitter extends Actor with Stash { - - def receive = { + def receive: Receive = { case (a: ActorRef, b: ActorRef) => unstashAll() context watch a context watch b context become ready(a, b) - - case msg => stash() + case _ => stash() } def ready(a: ActorRef, b: ActorRef): Receive = { @@ -296,15 +252,13 @@ object TransportHandlerSpec { } class MyPipePull extends Actor with Stash { - - def receive = { + def receive: Receive = { case (a: ActorRef, b: ActorRef) => unstashAll() context watch a context watch b context become ready(a, b, aResume = true, bResume = true) - - case msg => stash() + case _ => stash() } def ready(a: ActorRef, b: ActorRef, aResume: Boolean, bResume: Boolean): Receive = { @@ -336,7 +290,7 @@ object TransportHandlerSpec { case _ => SupervisorStrategy.stop } - def receive = { + def receive: Receive = { case _ => () } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala index 4b9b1a152a..dcefce23aa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala @@ -21,8 +21,9 @@ import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32} import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional} -import fr.acinq.eclair.Features.{BasicMultiPartPayment, ChannelRangeQueries, ChannelType, PaymentSecret, StaticRemoteKey, VariableLengthOnion} +import fr.acinq.eclair.Features._ import fr.acinq.eclair.TestConstants._ +import fr.acinq.eclair.TestUtils.randomTxId import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.ConnectionDown import fr.acinq.eclair.message.OnionMessages.{Recipient, buildMessage} @@ -333,6 +334,100 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi transport.expectNoMessage(10 / transport.testKitSettings.TestTimeFactor seconds) // we don't want dilated time here } + test("send batch of commit_sig messages") { f => + import f._ + val probe = TestProbe() + connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) + val channelId = randomBytes32() + val commitSigs = Seq( + CommitSig(channelId, randomBytes64(), Nil), + CommitSig(channelId, randomBytes64(), Nil), + CommitSig(channelId, randomBytes64(), Nil), + ) + probe.send(peerConnection, CommitSigBatch(commitSigs)) + commitSigs.foreach(commitSig => transport.expectMsg(commitSig)) + transport.expectNoMessage(100 millis) + } + + test("receive legacy batch of commit_sig messages") { f => + import f._ + connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer) + + // We receive a batch of commit_sig messages from a first channel. + val channelId1 = randomBytes32() + val commitSigs1 = Seq( + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + ) + transport.send(peerConnection, commitSigs1.head) + transport.expectMsg(TransportHandler.ReadAck(commitSigs1.head)) + peer.expectNoMessage(100 millis) + transport.send(peerConnection, commitSigs1.last) + transport.expectMsg(TransportHandler.ReadAck(commitSigs1.last)) + peer.expectMsg(CommitSigBatch(commitSigs1)) + + // We receive a batch of commit_sig messages from a second channel. + val channelId2 = randomBytes32() + val commitSigs2 = Seq( + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))), + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))), + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))), + ) + commitSigs2.dropRight(1).foreach(commitSig => { + transport.send(peerConnection, commitSig) + transport.expectMsg(TransportHandler.ReadAck(commitSig)) + }) + peer.expectNoMessage(100 millis) + transport.send(peerConnection, commitSigs2.last) + transport.expectMsg(TransportHandler.ReadAck(commitSigs2.last)) + peer.expectMsg(CommitSigBatch(commitSigs2)) + + // We receive another batch of commit_sig messages from the first channel, with unrelated messages in the batch. + val commitSigs3 = Seq( + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + ) + transport.send(peerConnection, commitSigs3.head) + transport.expectMsg(TransportHandler.ReadAck(commitSigs3.head)) + val spliceLocked1 = SpliceLocked(channelId1, randomTxId()) + transport.send(peerConnection, spliceLocked1) + transport.expectMsg(TransportHandler.ReadAck(spliceLocked1)) + peer.expectMsg(spliceLocked1) + val spliceLocked2 = SpliceLocked(channelId2, randomTxId()) + transport.send(peerConnection, spliceLocked2) + transport.expectMsg(TransportHandler.ReadAck(spliceLocked2)) + peer.expectMsg(spliceLocked2) + peer.expectNoMessage(100 millis) + transport.send(peerConnection, commitSigs3.last) + transport.expectMsg(TransportHandler.ReadAck(commitSigs3.last)) + peer.expectMsg(CommitSigBatch(commitSigs3)) + + // We start receiving a batch of commit_sig messages from the first channel, interleaved with a batch from the second + // channel, which is not supported. + val commitSigs4 = Seq( + CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))), + ) + transport.send(peerConnection, commitSigs4.head) + transport.expectMsg(TransportHandler.ReadAck(commitSigs4.head)) + peer.expectNoMessage(100 millis) + transport.send(peerConnection, commitSigs4(1)) + transport.expectMsg(TransportHandler.ReadAck(commitSigs4(1))) + peer.expectMsg(CommitSigBatch(commitSigs4.take(1))) + transport.send(peerConnection, commitSigs4.last) + transport.expectMsg(TransportHandler.ReadAck(commitSigs4.last)) + peer.expectMsg(CommitSigBatch(commitSigs4.tail)) + + // We receive a batch that exceeds our threshold: we process them individually. + val invalidCommitSigs = (0 until 30).map(_ => CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(30)))) + invalidCommitSigs.foreach(commitSig => { + transport.send(peerConnection, commitSig) + transport.expectMsg(TransportHandler.ReadAck(commitSig)) + peer.expectMsg(commitSig) + }) + } + test("react to peer's bad behavior") { f => import f._ val probe = TestProbe()