diff --git a/chain-signatures/node/src/cli.rs b/chain-signatures/node/src/cli.rs index 07ec85d2..23090b6a 100644 --- a/chain-signatures/node/src/cli.rs +++ b/chain-signatures/node/src/cli.rs @@ -5,6 +5,7 @@ use crate::mesh::Mesh; use crate::node_client::{self, NodeClient}; use crate::protocol::message::MessageChannel; use crate::protocol::presignature::Presignature; +use crate::protocol::signature::SignatureSpawnerTask; use crate::protocol::state::Node; use crate::protocol::sync::SyncTask; use crate::protocol::{spawn_system_metrics, MpcSignProtocol}; @@ -22,8 +23,7 @@ use mpc_keys::hpke; use near_account_id::AccountId; use near_crypto::{InMemorySigner, PublicKey, SecretKey}; use sha3::Digest; -use std::sync::Arc; -use tokio::sync::{mpsc, watch, RwLock}; +use tokio::sync::{mpsc, watch}; use url::Url; const DEFAULT_WEB_PORT: u16 = 3000; @@ -315,21 +315,34 @@ pub async fn run(cmd: Cli) -> anyhow::Result<()> { contract_watcher.clone(), ) .await; + + // Start the signature spawner task immediately. + // It will wait for the contract to reach Running state before processing requests, + // and dynamically obtains governance info from the contract state. + let sign_task = SignatureSpawnerTask::run( + account_id.clone(), + sign_rx, + contract_watcher.clone(), + config_rx.clone(), + presignature_storage.clone(), + mesh_state.clone(), + msg_channel.clone(), + rpc_channel.clone(), + backlog.clone(), + ); + let protocol = MpcSignProtocol { my_account_id: account_id.clone(), - rpc_channel, msg_channel: msg_channel.clone(), generating: msg_channel.subscribe_generation().await, resharing: msg_channel.subscribe_resharing().await, ready: msg_channel.subscribe_ready().await, - sign_rx: Arc::new(RwLock::new(sign_rx)), + sign_task, secret_storage: key_storage, triple_storage: triple_storage.clone(), presignature_storage: presignature_storage.clone(), - contract: contract_watcher.clone(), config: config_rx, mesh_state: mesh_state.clone(), - backlog: backlog.clone(), }; tracing::info!("protocol initialized"); diff --git a/chain-signatures/node/src/protocol/consensus.rs b/chain-signatures/node/src/protocol/consensus.rs index bce96aa8..5733fc05 100644 --- a/chain-signatures/node/src/protocol/consensus.rs +++ b/chain-signatures/node/src/protocol/consensus.rs @@ -6,7 +6,6 @@ use super::state::{ use super::MpcSignProtocol; use crate::protocol::contract::primitives::Participants; use crate::protocol::presignature::PresignatureSpawnerTask; -use crate::protocol::signature::SignatureSpawnerTask; use crate::protocol::state::GeneratingState; use crate::protocol::triple::TripleSpawnerTask; use crate::protocol::Governance; @@ -107,14 +106,6 @@ impl ConsensusProtocol for StartedState { &public_key, ); - let sign_task = SignatureSpawnerTask::run( - me, - contract_state.threshold, - epoch, - ctx, - public_key, - ); - NodeState::Running(RunningState { epoch, me, @@ -124,7 +115,6 @@ impl ConsensusProtocol for StartedState { public_key, triple_task, presign_task, - sign_task, }) } } @@ -408,13 +398,6 @@ impl ConsensusProtocol for WaitingForConsensusState { &self.private_share, &self.public_key, ); - let sign_task = SignatureSpawnerTask::run( - me, - self.threshold, - self.epoch, - ctx, - self.public_key, - ); NodeState::Running(RunningState { epoch: self.epoch, @@ -425,7 +408,6 @@ impl ConsensusProtocol for WaitingForConsensusState { public_key: self.public_key, triple_task, presign_task, - sign_task, }) } }, diff --git a/chain-signatures/node/src/protocol/mod.rs b/chain-signatures/node/src/protocol/mod.rs index 58d6b361..dc3dd06e 
100644 --- a/chain-signatures/node/src/protocol/mod.rs +++ b/chain-signatures/node/src/protocol/mod.rs @@ -21,14 +21,14 @@ pub use mpc_primitives::Chain; pub use signature::{IndexedSignRequest, Sign}; pub use state::{Node, NodeState}; -use crate::backlog::Backlog; use crate::config::Config; use crate::mesh::MeshState; use crate::protocol::consensus::ConsensusProtocol; use crate::protocol::cryptography::CryptographicProtocol; use crate::protocol::message::{GeneratingMessage, ReadyMessage, ResharingMessage}; +use crate::protocol::signature::SignatureSpawnerTask; use crate::respond_bidirectional::RespondBidirectionalTx; -use crate::rpc::{ContractStateWatcher, RpcChannel}; +use crate::rpc::ContractStateWatcher; use crate::storage::presignature_storage::PresignatureStorage; use crate::storage::secret_storage::SecretNodeStorageBox; use crate::storage::triple_storage::TripleStorage; @@ -36,10 +36,8 @@ use crate::storage::triple_storage::TripleStorage; use near_account_id::AccountId; use semver::Version; use std::path::Path; -use std::sync::Arc; use std::time::{Duration, Instant}; use sysinfo::{CpuRefreshKind, Disks, RefreshKind, System}; -use tokio::sync::RwLock; use tokio::sync::{mpsc, watch}; pub struct MpcSignProtocol { @@ -47,16 +45,19 @@ pub struct MpcSignProtocol { pub(crate) secret_storage: SecretNodeStorageBox, pub(crate) triple_storage: TripleStorage, pub(crate) presignature_storage: PresignatureStorage, - pub(crate) sign_rx: Arc>>, + pub(crate) sign_task: SignatureSpawnerTask, pub(crate) generating: mpsc::Receiver, pub(crate) resharing: mpsc::Receiver, pub(crate) ready: mpsc::Receiver, pub(crate) msg_channel: MessageChannel, - pub(crate) rpc_channel: RpcChannel, - pub(crate) contract: ContractStateWatcher, pub(crate) config: watch::Receiver, pub(crate) mesh_state: watch::Receiver, - pub(crate) backlog: Backlog, +} + +impl Drop for MpcSignProtocol { + fn drop(&mut self) { + self.sign_task.abort(); + } } /// Interface required by the [`MpcSignProtocol`] to participate in the diff --git a/chain-signatures/node/src/protocol/signature.rs b/chain-signatures/node/src/protocol/signature.rs index 1b217b33..5361c8e7 100644 --- a/chain-signatures/node/src/protocol/signature.rs +++ b/chain-signatures/node/src/protocol/signature.rs @@ -1,4 +1,3 @@ -use super::MpcSignProtocol; use crate::backlog::Backlog; use crate::config::Config; use crate::kdf::derive_delta; @@ -10,27 +9,26 @@ use crate::protocol::message::{ use crate::protocol::posit::{PositAction, SinglePositCounter}; use crate::protocol::presignature::PresignatureId; use crate::protocol::Chain; -use crate::rpc::{ContractStateWatcher, RpcChannel}; +use crate::protocol::{ProtocolState, SignRequestType}; +use crate::rpc::{ContractStateWatcher, GovernanceInfo, RpcChannel}; use crate::storage::presignature_storage::{PresignatureTaken, PresignatureTakenDropper}; use crate::storage::PresignatureStorage; use crate::types::SignatureProtocol; use crate::util::{AffinePointExt, JoinMap, TimeoutBudget}; -use crate::protocol::SignRequestType; use cait_sith::protocol::{Action, InitializationError, Participant}; use cait_sith::PresignOutput; use chrono::Utc; use k256::Secp256k1; use mpc_contract::config::ProtocolConfig; -use mpc_crypto::{derive_key, PublicKey}; +use mpc_crypto::derive_key; use mpc_primitives::{SignArgs, SignId}; use rand::rngs::StdRng; use rand::seq::IteratorRandom; use rand::SeedableRng; use std::collections::{BTreeSet, HashMap, VecDeque}; -use std::sync::Arc; use std::time::{Duration, Instant}; -use tokio::sync::{mpsc, watch, RwLock}; +use 
tokio::sync::{mpsc, watch}; use tokio::task::JoinHandle; use near_account_id::AccountId; @@ -157,19 +155,57 @@ enum SignPhase { Complete(Result<(), SignError>), } +/// Context passed to phase advancement - contains immutable references to task resources +struct SignContext<'a> { + gov: &'a GovernanceInfo, + sign_id: SignId, + my_account_id: &'a AccountId, + presignatures: &'a PresignatureStorage, + msg: &'a MessageChannel, + rpc: &'a RpcChannel, + cfg: &'a ProtocolConfig, +} + +impl<'a> SignContext<'a> { + fn me(&self) -> Participant { + self.gov.me + } + + fn threshold(&self) -> usize { + self.gov.threshold + } + + fn epoch(&self) -> u64 { + self.gov.epoch + } + + fn public_key(&self) -> mpc_crypto::PublicKey { + self.gov.public_key + } + + fn participants(&self) -> &BTreeSet { + &self.gov.participants + } +} + impl SignPhase { async fn advance( - self, - ctx: &SignTask, + &mut self, + ctx: &SignContext<'_>, state: &mut SignState, task_rx: &mut mpsc::Receiver, - ) -> SignPhase { - match self { + ) { + // Use take pattern to extract the current phase and replace with a temporary Complete(Ok(())) + // This allows us to call advance methods that consume self while still allowing &mut self here + let current = std::mem::replace(self, SignPhase::Complete(Ok(()))); + + let next = match current { SignPhase::Organizing(phase) => phase.advance(ctx, state).await, SignPhase::Posit(phase) => phase.advance(ctx, state, task_rx).await, SignPhase::Generating(phase) => phase.advance(ctx, state).await, SignPhase::Complete(result) => SignPhase::Complete(result), - } + }; + *self = next; } } @@ -188,7 +224,7 @@ impl SignOrganizer { /// Waits for threshold stable participants to be present. async fn wait_stable( &self, - ctx: &SignTask, + ctx: &SignContext<'_>, state: &mut SignState, threshold: usize, ) -> Option> { @@ -220,12 +256,12 @@ impl SignOrganizer { } } - async fn advance(self, ctx: &SignTask, state: &mut SignState) -> SignPhase { + async fn advance(self, ctx: &SignContext<'_>, state: &mut SignState) -> SignPhase { let sign_id = ctx.sign_id; - let threshold = ctx.threshold; - let me = ctx.me; + let threshold = ctx.threshold(); + let me = ctx.me(); let entropy = state.indexed.args.entropy; - let participants = ctx.participants.iter().copied().collect::>(); + let participants = ctx.participants().iter().copied().collect::>(); tracing::info!(?sign_id, round = ?state.round, "entering organizing phase"); let (stable, proposer) = { @@ -271,7 +307,7 @@ impl SignOrganizer { (stable, proposer) }; - let is_proposer = proposer == ctx.me; + let is_proposer = proposer == ctx.me(); let (presignature_id, presignature, stable) = if is_proposer { tracing::info!(?sign_id, round = ?state.round, "proposer waiting for presignature"); let stable = stable.iter().copied().collect::>(); @@ -279,9 +315,9 @@ impl SignOrganizer { let remaining = state.budget.remaining(); let fetch = tokio::time::timeout(remaining, async { loop { - if let Some(taken) = ctx.presignatures.take_mine(ctx.me).await { + if let Some(taken) = ctx.presignatures.take_mine(ctx.me()).await { let participants = intersect_vec(&[&taken.artifact.participants, &stable]); - if participants.len() < ctx.threshold { + if participants.len() < ctx.threshold() { recycle.push(taken); continue; } @@ -319,16 +355,16 @@ impl SignOrganizer { // broadcast to participants and let them reject if they don't have the presignature. 
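+ // Recipients answer the Propose broadcast below with `PositAction::Accept` or
+ // `PositAction::Reject`; the proposer tallies the replies with `SinglePositCounter`
+ // and, once it sees `enough_rejects(threshold)`, recycles the presignature and
+ // reorganizes (see `SignPositor::advance`).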
for &p in &participants { - if p == ctx.me { + if p == ctx.me() { continue; } ctx.msg .send( - ctx.me, + ctx.me(), p, PositMessage { id: PositProtocolId::Signature(sign_id, presignature_id, state.round), - from: ctx.me, + from: ctx.me(), action: PositAction::Propose, }, ) @@ -354,7 +390,7 @@ impl SignOrganizer { impl SignPositor { /// Deliberator waits for the proposer to send a Propose message with a presignature_id. async fn wait_propose( - ctx: &SignTask, + ctx: &SignContext<'_>, state: &mut SignState, task_rx: &mut mpsc::Receiver, proposer: Participant, @@ -392,7 +428,7 @@ impl SignPositor { if state.round > *peer_round { ctx.msg .send( - ctx.me, + ctx.me(), proposer, PositMessage { id: PositProtocolId::Signature( @@ -400,7 +436,7 @@ impl SignPositor { *presignature_id, *peer_round, ), - from: ctx.me, + from: ctx.me(), action: PositAction::Reject, }, ) @@ -421,7 +457,7 @@ impl SignPositor { continue; } - if from == &proposer { + if *from == proposer { tracing::info!( ?sign_id, presignature_id, @@ -438,7 +474,7 @@ impl SignPositor { ); ctx.msg .send( - ctx.me, + ctx.me(), proposer, PositMessage { id: PositProtocolId::Signature( @@ -446,7 +482,7 @@ impl SignPositor { *presignature_id, state.round, ), - from: ctx.me, + from: ctx.me(), action: PositAction::Reject, }, ) @@ -465,7 +501,7 @@ impl SignPositor { ctx.msg .send( - ctx.me, + ctx.me(), *from, PositMessage { id: PositProtocolId::Signature( @@ -473,7 +509,7 @@ impl SignPositor { *presignature_id, state.round, ), - from: ctx.me, + from: ctx.me(), action: PositAction::Reject, }, ) @@ -490,7 +526,7 @@ impl SignPositor { ?sign_id, ?round, ?proposer, - me=?ctx.me, + me = ?ctx.me(), "deliberator timeout waiting for Propose, reorganizing" ); state.bump_round(); @@ -501,11 +537,11 @@ impl SignPositor { // received propose, send Accept ctx.msg .send( - ctx.me, + ctx.me(), proposer, PositMessage { id: PositProtocolId::Signature(sign_id, presignature_id, state.round), - from: ctx.me, + from: ctx.me(), action: PositAction::Accept, }, ) @@ -516,7 +552,7 @@ impl SignPositor { async fn advance( self, - ctx: &SignTask, + ctx: &SignContext<'_>, state: &mut SignState, task_rx: &mut mpsc::Receiver, ) -> SignPhase { @@ -529,7 +565,7 @@ impl SignPositor { let sign_id = ctx.sign_id; let round = state.round; - let is_proposer = proposer == ctx.me; + let is_proposer = proposer == ctx.me(); let is_deliberator = !is_proposer; // Get the presignature participants - only these nodes participated in generating it @@ -564,7 +600,7 @@ impl SignPositor { // GUARANTEE: at least threshold participants from organizing phase. 
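+ // Accepts from nodes that did not take part in generating this presignature are
+ // filtered out further down (once the presignature participant set is known), so
+ // the broadcast Start set stays within `presignature_participants`.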
let posit_participants = stable.iter().copied().collect::>(); - let mut counter = SinglePositCounter::new(ctx.me, &posit_participants); + let mut counter = SinglePositCounter::new(ctx.me(), &posit_participants); let remaining = state.budget.remaining(); let posit_deadline = tokio::time::sleep(remaining); @@ -599,7 +635,7 @@ impl SignPositor { continue; } - if participants.len() < ctx.threshold { + if participants.len() < ctx.threshold() { tracing::warn!( ?sign_id, ?round, @@ -609,7 +645,7 @@ impl SignPositor { return SignPhase::Organizing(SignOrganizer); } - tracing::info!(?sign_id, participant = ?ctx.me, ?participants, "deliberator received Start"); + tracing::info!(?sign_id, participant = ?ctx.me(), ?participants, "deliberator received Start"); break participants; } } else { @@ -617,11 +653,11 @@ impl SignPositor { continue; } - if counter.enough_rejects(ctx.threshold) { + if counter.enough_rejects(ctx.threshold()) { tracing::warn!(?sign_id, ?from, "received enough REJECTs, reorganizing"); if let Some(taken) = presignature { tracing::warn!(?sign_id, "recycling presignature due to REJECTs"); - ctx.presignatures.recycle_mine(ctx.me, taken).await; + ctx.presignatures.recycle_mine(ctx.me(), taken).await; } state.bump_round(); return SignPhase::Organizing(SignOrganizer); @@ -634,35 +670,35 @@ impl SignPositor { participants.retain(|p| presignature_participants.contains(p)); } - if participants.len() < ctx.threshold { + if participants.len() < ctx.threshold() { tracing::warn!( ?sign_id, presig_participants = ?presignature_participants, accepts = ?counter.accepts, filtered_participants = ?participants, - threshold = ctx.threshold, + threshold = ctx.threshold(), "not enough presignature participants accepted, reorganizing" ); if let Some(taken) = presignature { tracing::warn!(?sign_id, "recycling presignature due to insufficient participants"); - ctx.presignatures.recycle_mine(ctx.me, taken).await; + ctx.presignatures.recycle_mine(ctx.me(), taken).await; } state.bump_round(); return SignPhase::Organizing(SignOrganizer); } - tracing::info!(?sign_id, me = ?ctx.me, ?participants, "proposer broadcasting Start"); + tracing::info!(?sign_id, me = ?ctx.me(), ?participants, "proposer broadcasting Start"); for &p in &participants { - if p == ctx.me { + if p == ctx.me() { continue; } ctx.msg .send( - ctx.me, + ctx.me(), p, PositMessage { id: PositProtocolId::Signature(sign_id, presignature_id, state.round), - from: ctx.me, + from: ctx.me(), action: PositAction::Start(participants.clone()), }, ) @@ -674,25 +710,25 @@ impl SignPositor { } _ = &mut posit_deadline => { if is_proposer { - if counter.enough_accepts(ctx.threshold) { + if counter.enough_accepts(ctx.threshold()) { // Only include participants who both accepted AND were part of the presignature generation let mut participants = counter.accepts.iter().copied().collect::>(); if !presignature_participants.is_empty() { participants.retain(|p| presignature_participants.contains(p)); } - if participants.len() < ctx.threshold { + if participants.len() < ctx.threshold() { tracing::warn!( ?sign_id, presig_participants = ?presignature_participants, accepts = ?counter.accepts, filtered_participants = ?participants, - threshold = ctx.threshold, + threshold = ctx.threshold(), "posit timeout: not enough presignature participants accepted, reorganizing" ); if let Some(taken) = presignature { tracing::warn!(?sign_id, "recycling presignature due to posit timeout"); - ctx.presignatures.recycle_mine(ctx.me, taken).await; + ctx.presignatures.recycle_mine(ctx.me(), 
taken).await; } state.bump_round(); return SignPhase::Organizing(SignOrganizer); @@ -700,16 +736,16 @@ impl SignPositor { tracing::info!(?sign_id, "posit timeout with enough accepts, broadcasting Start"); for &p in &participants { - if p == ctx.me { + if p == ctx.me() { continue; } ctx.msg .send( - ctx.me, + ctx.me(), p, PositMessage { id: PositProtocolId::Signature(sign_id, presignature_id, state.round), - from: ctx.me, + from: ctx.me(), action: PositAction::Start(participants.clone()), }, ) @@ -719,12 +755,13 @@ impl SignPositor { } else { tracing::warn!( ?sign_id, - accepts=counter.accepts.len(), - threshold=ctx.threshold, - "posit timeout without enough accepts, reorganizing"); + accepts = counter.accepts.len(), + threshold = ctx.threshold(), + "posit timeout without enough accepts, reorganizing", + ); if let Some(taken) = presignature { tracing::warn!(?sign_id, "recycling presignature due to posit timeout (no accepts)"); - ctx.presignatures.recycle_mine(ctx.me, taken).await; + ctx.presignatures.recycle_mine(ctx.me(), taken).await; } state.bump_round(); return SignPhase::Organizing(SignOrganizer); @@ -748,7 +785,7 @@ impl SignPositor { } impl SignGenerating { - async fn advance(mut self, ctx: &SignTask, state: &mut SignState) -> SignPhase { + async fn advance(mut self, ctx: &SignContext<'_>, state: &mut SignState) -> SignPhase { let sign_id = ctx.sign_id; let round = state.round; @@ -803,7 +840,7 @@ impl SignGenerating { ?sign_id, ?round, ?err, - me=?ctx.me, + me = ?ctx.me(), "signature generation failed, reorganizing" ); state.bump_round(); @@ -831,7 +868,7 @@ struct SignGenerator { impl SignGenerator { async fn new( - ctx: &SignTask, + ctx: &SignContext<'_>, proposer: Participant, indexed: IndexedSignRequest, presignature: PendingPresignature, @@ -849,7 +886,7 @@ impl SignGenerator { let sign_id = indexed.id; tracing::info!( - me = ?ctx.me, + me = ?ctx.me(), ?sign_id, presignature_id, "starting protocol to generate a new signature", @@ -866,8 +903,8 @@ impl SignGenerator { }; let protocol = Box::new(cait_sith::sign( &participants, - ctx.me, - derive_key(ctx.public_key, indexed.args.epsilon), + ctx.me(), + derive_key(ctx.public_key(), indexed.args.epsilon), output, indexed.args.payload, )?); @@ -912,10 +949,10 @@ impl SignGenerator { } } - async fn run(mut self, ctx: &SignTask) -> Result<(), SignError> { + async fn run(mut self, ctx: &SignContext<'_>) -> Result<(), SignError> { let my_account_id = &ctx.my_account_id; - let me = ctx.me; - let epoch = ctx.epoch; + let me = ctx.me(); + let epoch = ctx.epoch(); let accrued_wait_delay = crate::metrics::protocols::SIGNATURE_ACCRUED_WAIT_DELAY .with_label_values(&[my_account_id.as_str()]); @@ -1032,7 +1069,7 @@ impl SignGenerator { if self.proposer == me { ctx.rpc.publish( - ctx.public_key, + ctx.public_key(), self.indexed.clone(), output, self.participants.clone(), @@ -1081,16 +1118,13 @@ impl Drop for SignGenerator { } struct SignTask { - me: Participant, - participants: BTreeSet, + gov: GovernanceInfo, sign_id: SignId, - threshold: usize, - public_key: PublicKey, - epoch: u64, my_account_id: AccountId, presignatures: PresignatureStorage, msg: MessageChannel, rpc: RpcChannel, + contract: ContractStateWatcher, // TODO: will be used in the future when we move requests channels // into the backlog. 
@@ -1098,22 +1132,20 @@ struct SignTask { backlog: Backlog, cfg: ProtocolConfig, - contract: ContractStateWatcher, } impl SignTask { async fn run( - self, + mut self, indexed: IndexedSignRequest, mesh_state: watch::Receiver, mut task_rx: mpsc::Receiver, ) -> Result<(), SignError> { let sign_id = self.sign_id; - let task_epoch = self.epoch; tracing::info!( ?sign_id, - me = ?self.me, - epoch = task_epoch, + me = ?self.gov.me, + epoch = self.gov.epoch, "signature task starting with organizing loop" ); @@ -1121,35 +1153,85 @@ impl SignTask { let mut phase = SignPhase::Organizing(SignOrganizer); loop { - // Check if we should abort due to resharing or epoch change - if let Some(contract_state) = self.contract.state() { - match contract_state { - crate::protocol::ProtocolState::Resharing(_) => { - tracing::info!( - ?sign_id, - epoch = task_epoch, - "signature task interrupted: contract is resharing" - ); - return Err(SignError::Aborted); - } - crate::protocol::ProtocolState::Running(running) - if running.epoch != task_epoch => - { - tracing::info!( - ?sign_id, - old_epoch = task_epoch, - new_epoch = running.epoch, - "signature task interrupted: epoch changed" - ); + // Create context before select to avoid borrow issues + let ctx = SignContext { + gov: &self.gov, + sign_id: self.sign_id, + my_account_id: &self.my_account_id, + presignatures: &self.presignatures, + msg: &self.msg, + rpc: &self.rpc, + cfg: &self.cfg, + }; + + tokio::select! { + // Contract state changes take priority - cancel ongoing work + biased; + contract = self.contract.next_state() => { + let Some(contract) = contract else { + tracing::warn!(?sign_id, "contract state channel closed"); return Err(SignError::Aborted); + }; + + match contract { + ProtocolState::Running(running) => { + // Update governance info if epoch changed + if running.epoch != self.gov.epoch { + if let Some(&me) = running.participants.find_participant(&self.my_account_id) { + tracing::info!( + ?sign_id, + old_epoch = self.gov.epoch, + new_epoch = running.epoch, + ?me, + "sign task: epoch changed, updating governance info" + ); + self.gov = GovernanceInfo { + me, + threshold: running.threshold, + epoch: running.epoch, + public_key: running.public_key, + participants: running.participants.keys().copied().collect(), + }; + // Reset to organizing phase with new governance + state.round = 0; + phase = SignPhase::Organizing(SignOrganizer); + } else { + tracing::warn!( + ?sign_id, + epoch = running.epoch, + "sign task: we are no longer a participant after epoch change" + ); + return Err(SignError::Aborted); + } + } + // If same epoch and running, just continue + } + ProtocolState::Initializing(_) | ProtocolState::Resharing(_) => { + tracing::info!( + ?sign_id, + epoch = self.gov.epoch, + "sign task: contract entered non-running state, waiting for running" + ); + // Wait for contract to return to Running state + self.gov = self.contract.wait_governance().await; + tracing::info!( + ?sign_id, + new_epoch = self.gov.epoch, + "sign task: contract returned to running, resuming" + ); + // Reset to organizing phase + state.round = 0; + phase = SignPhase::Organizing(SignOrganizer); + } } - _ => {} } - } - phase = match phase.advance(&self, &mut state, &mut task_rx).await { - SignPhase::Complete(result) => return result, - other => other, + // This will be cancelled when the contract state changes. 
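+ // Because the select is `biased`, a pending contract-state update is always
+ // observed before more phase work is polled, and completing that branch drops
+ // the in-flight `advance` future, which is how ongoing posit/generation work
+ // for this request gets cancelled when resharing starts.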
+ _ = phase.advance(&ctx, &mut state, &mut task_rx) => { + if let SignPhase::Complete(result) = phase { + return result; + } + } } } } @@ -1157,6 +1239,7 @@ impl SignTask { /// Message types that can be sent to a running signature task enum SignTaskMessage { + /// Posit message from another node PositMessage { presignature_id: PresignatureId, round: usize, @@ -1173,12 +1256,9 @@ pub struct SignatureSpawner { /// Buffered inboxes for posit messages, allowing us to queue before tasks spawn inboxes: HashMap>, mesh_state: watch::Receiver, + contract: ContractStateWatcher, - me: Participant, my_account_id: AccountId, - threshold: usize, - public_key: PublicKey, - epoch: u64, msg: MessageChannel, rpc: RpcChannel, backlog: Backlog, @@ -1190,8 +1270,7 @@ impl SignatureSpawner { fn spawn_task( &mut self, indexed: IndexedSignRequest, - participants: BTreeSet, - contract: ContractStateWatcher, + gov: &GovernanceInfo, cfg: ProtocolConfig, ) { let sign_id = indexed.id; @@ -1200,19 +1279,15 @@ impl SignatureSpawner { // Subscribe to (or create) the posit inbox for this sign request let rx = self.inboxes.entry(sign_id).or_default().subscribe(); let task = SignTask { - me: self.me, - participants, + gov: gov.clone(), sign_id, - threshold: self.threshold, - public_key: self.public_key, - epoch: self.epoch, my_account_id: self.my_account_id.clone(), presignatures: self.presignatures.clone(), msg: self.msg.clone(), rpc: self.rpc.clone(), + contract: self.contract.clone(), backlog: self.backlog.clone(), cfg, - contract, }; // Spawn the async task with organizing loop @@ -1228,9 +1303,10 @@ impl SignatureSpawner { round: usize, from: Participant, action: PositAction, + gov: &GovernanceInfo, ) { // Ignore messages from ourselves - if from == self.me { + if from == gov.me { return; } let _ = self @@ -1255,13 +1331,7 @@ impl SignatureSpawner { } } - fn handle_request( - &mut self, - sign: Sign, - cfg: &ProtocolConfig, - participants: &BTreeSet, - contract: &ContractStateWatcher, - ) { + fn handle_request(&mut self, sign: Sign, cfg: &ProtocolConfig, gov: &GovernanceInfo) { match sign { Sign::Completion(sign_id) => { self.handle_completion(sign_id); @@ -1282,7 +1352,7 @@ impl SignatureSpawner { .with_label_values(&[indexed.chain.as_str(), self.my_account_id.as_str()]) .inc(); - self.spawn_task(indexed, participants.clone(), contract.clone(), cfg.clone()); + self.spawn_task(indexed, gov, cfg.clone()); } } @@ -1294,19 +1364,22 @@ impl SignatureSpawner { async fn run( mut self, - sign_rx: Arc>>, + mut sign_rx: mpsc::Receiver, mut contract: ContractStateWatcher, mut cfg: watch::Receiver, ) { let mut posits = self.msg.subscribe_signature_posit().await; - let running = contract.wait_running().await; - let all_participants = running.participants.keys().copied().collect(); - let mut protocol = cfg.borrow().protocol.clone(); + // Wait for initial governance info + let mut gov = contract.wait_governance().await; + tracing::info!( + me = ?gov.me, + epoch = gov.epoch, + threshold = gov.threshold, + "signature spawner initialized with governance info" + ); - // we acquire the lock but since this is a tokio lock, aborting the task while holding - // the lock is safe and will not deadlock other tasks trying to acquire the lock - let mut sign_rx = sign_rx.write().await; + let mut protocol = cfg.borrow().protocol.clone(); loop { tokio::select! 
{ @@ -1315,10 +1388,10 @@ impl SignatureSpawner { tracing::warn!("signature spawner sign_rx closed, terminating"); break; }; - self.handle_request(sign, &protocol, &all_participants, &contract); + self.handle_request(sign, &protocol, &gov); } Some((sign_id, presignature_id, round, from, action)) = posits.recv() => { - self.handle_posit(sign_id, presignature_id, round, from, action).await; + self.handle_posit(sign_id, presignature_id, round, from, action, &gov).await; } Some(result) = self.tasks.join_next(), if !self.tasks.is_empty() => { let (sign_id, result) = match result { @@ -1343,6 +1416,50 @@ impl SignatureSpawner { Ok(()) = cfg.changed() => { protocol = cfg.borrow().protocol.clone(); } + _ = contract.next_state() => { + // Contract state changed - check if we need to update governance info + if let Some(state) = contract.state() { + match state { + ProtocolState::Running(running) => { + if running.epoch != gov.epoch { + // Epoch changed, update governance info + if let Some(me) = running.participants.find_participant(contract.my_account_id()) { + tracing::info!( + old_epoch = gov.epoch, + new_epoch = running.epoch, + new_me = ?me, + "signature spawner: epoch changed, updating governance info" + ); + gov = GovernanceInfo { + me: *me, + threshold: running.threshold, + epoch: running.epoch, + public_key: running.public_key, + participants: running.participants.keys().copied().collect(), + }; + // Tasks will notice the state change and update themselves + } else { + tracing::warn!( + epoch = running.epoch, + "signature spawner: we are no longer a participant after epoch change" + ); + // Tasks will notice and wait_governance will block until we're a participant again + } + } + } + ProtocolState::Resharing(_) => { + tracing::info!( + epoch = gov.epoch, + "signature spawner: contract is resharing, tasks will pause themselves" + ); + // Tasks will notice the state change and pause themselves + } + ProtocolState::Initializing(_) => { + tracing::debug!("signature spawner: contract is initializing"); + } + } + } + } } } } @@ -1360,51 +1477,52 @@ pub struct SignatureSpawnerTask { } impl SignatureSpawnerTask { + /// Start the signature spawner task. + /// + /// The spawner will wait for the contract to reach Running state before processing requests. + /// It dynamically obtains governance info (me, threshold, epoch, public_key) from the contract + /// state and updates when the contract transitions to a new epoch after resharing. 
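+ ///
+ /// Illustrative call site, mirroring the wiring in `cli.rs` (the argument names are
+ /// the caller's locals, not part of this API):
+ ///
+ /// ```ignore
+ /// let sign_task = SignatureSpawnerTask::run(
+ ///     account_id.clone(),
+ ///     sign_rx,
+ ///     contract_watcher.clone(),
+ ///     config_rx.clone(),
+ ///     presignature_storage.clone(),
+ ///     mesh_state.clone(),
+ ///     msg_channel.clone(),
+ ///     rpc_channel.clone(),
+ ///     backlog.clone(),
+ /// );
+ /// ```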
+ #[allow(clippy::too_many_arguments)] pub fn run( - me: Participant, - threshold: usize, - epoch: u64, - ctx: &MpcSignProtocol, - public_key: PublicKey, + my_account_id: AccountId, + sign_rx: mpsc::Receiver, + contract: ContractStateWatcher, + config: watch::Receiver, + presignature_storage: PresignatureStorage, + mesh_state: watch::Receiver, + msg_channel: MessageChannel, + rpc_channel: RpcChannel, + backlog: Backlog, ) -> Self { let spawner = SignatureSpawner { - me, tasks: JoinMap::new(), inboxes: HashMap::new(), - my_account_id: ctx.my_account_id.clone(), - threshold, - public_key, - epoch, - presignatures: ctx.presignature_storage.clone(), - mesh_state: ctx.mesh_state.clone(), - msg: ctx.msg_channel.clone(), - rpc: ctx.rpc_channel.clone(), - backlog: ctx.backlog.clone(), + my_account_id, + presignatures: presignature_storage, + mesh_state, + contract: contract.clone(), + msg: msg_channel, + rpc: rpc_channel, + backlog, }; Self { - handle: tokio::spawn(spawner.run( - ctx.sign_rx.clone(), - ctx.contract.clone(), - ctx.config.clone(), - )), + handle: tokio::spawn(spawner.run(sign_rx, contract, config)), } } pub fn abort(&self) { - // NOTE: since dropping the handle here, PresignatureSpawner will drop their JoinSet/JoinMap - // which will also abort all ongoing presignature generation tasks. This is important to note - // since we do not want to leak any presignature generation tasks when we are resharing, and - // potentially wasting compute. + // NOTE: This aborts the spawner task and all ongoing signature generation tasks. + // This is typically only called during shutdown or catastrophic error recovery. + // During normal operation (including resharing), the spawner should NOT be aborted + // since sign tasks need to persist across resharing. self.handle.abort(); } } -impl Drop for SignatureSpawnerTask { - fn drop(&mut self) { - self.abort(); - } -} +// NOTE: We intentionally do NOT implement Drop for SignatureSpawnerTask. +// The spawner should persist across resharing, so dropping it should not abort the task. +// The spawner will naturally terminate when all channels close (e.g., during node shutdown). 
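+// (Concretely: once every `mpsc::Sender<Sign>` handle is dropped, `sign_rx.recv()` returns
+// `None` and the spawner's run loop breaks on its own; `MpcSignProtocol`'s `Drop` impl
+// still aborts it explicitly on node shutdown.)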
enum PendingPresignature { Available(PresignatureTaken), diff --git a/chain-signatures/node/src/protocol/state.rs b/chain-signatures/node/src/protocol/state.rs index c3f1fe0b..d3d93e7f 100644 --- a/chain-signatures/node/src/protocol/state.rs +++ b/chain-signatures/node/src/protocol/state.rs @@ -1,7 +1,6 @@ use super::contract::{primitives::Participants, ResharingContractState}; use super::triple::TripleSpawnerTask; use crate::protocol::presignature::PresignatureSpawnerTask; -use crate::protocol::signature::SignatureSpawnerTask; use crate::types::{KeygenProtocol, ReshareProtocol, SecretKeyShare}; use cait_sith::protocol::Participant; @@ -77,7 +76,6 @@ pub struct RunningState { pub public_key: PublicKey, pub triple_task: TripleSpawnerTask, pub presign_task: PresignatureSpawnerTask, - pub sign_task: SignatureSpawnerTask, } pub struct ResharingState { diff --git a/chain-signatures/node/src/protocol/test_setup.rs b/chain-signatures/node/src/protocol/test_setup.rs index 1847e9c7..794f2dd4 100644 --- a/chain-signatures/node/src/protocol/test_setup.rs +++ b/chain-signatures/node/src/protocol/test_setup.rs @@ -1,14 +1,13 @@ -use std::sync::Arc; - use crate::backlog::Backlog; use crate::config::Config; use crate::mesh::MeshState; +use crate::protocol::signature::SignatureSpawnerTask; use crate::protocol::{MessageChannel, MpcSignProtocol, Sign}; use crate::rpc::{ContractStateWatcher, RpcChannel}; use crate::storage::secret_storage::SecretNodeStorageBox; use crate::storage::{PresignatureStorage, TripleStorage}; use near_sdk::AccountId; -use tokio::sync::{mpsc, watch, RwLock}; +use tokio::sync::{mpsc, watch}; pub struct TestProtocolStorage { pub secret_storage: SecretNodeStorageBox, @@ -17,7 +16,6 @@ pub struct TestProtocolStorage { } pub struct TestProtocolChannels { - pub sign_rx: Arc>>, pub msg_channel: MessageChannel, pub rpc_channel: RpcChannel, pub config: watch::Receiver, @@ -30,25 +28,38 @@ impl MpcSignProtocol { storage: TestProtocolStorage, channels: TestProtocolChannels, contract: ContractStateWatcher, + sign_rx: mpsc::Receiver, ) -> Self { let generating = channels.msg_channel.subscribe_generation().await; let resharing = channels.msg_channel.subscribe_resharing().await; let ready = channels.msg_channel.subscribe_ready().await; + let backlog = Backlog::new(); + + // Start the signature spawner task immediately + let sign_task = SignatureSpawnerTask::run( + my_account_id.clone(), + sign_rx, + contract.clone(), + channels.config.clone(), + storage.presignature_storage.clone(), + channels.mesh_state.clone(), + channels.msg_channel.clone(), + channels.rpc_channel.clone(), + backlog.clone(), + ); + Self { my_account_id, secret_storage: storage.secret_storage, triple_storage: storage.triple_storage, presignature_storage: storage.presignature_storage, - sign_rx: channels.sign_rx, + sign_task, msg_channel: channels.msg_channel, generating, resharing, ready, - rpc_channel: channels.rpc_channel, - contract, config: channels.config, mesh_state: channels.mesh_state, - backlog: Backlog::new(), } } } diff --git a/chain-signatures/node/src/rpc.rs b/chain-signatures/node/src/rpc.rs index ba98c343..29c75281 100644 --- a/chain-signatures/node/src/rpc.rs +++ b/chain-signatures/node/src/rpc.rs @@ -7,6 +7,9 @@ use crate::protocol::contract::RunningContractState; use crate::protocol::{Chain, Governance, IndexedSignRequest, ProtocolState, SignRequestType}; use crate::util::AffinePointExt as _; +use mpc_crypto::PublicKey; +use std::collections::BTreeSet; + use solana_sdk::commitment_config::CommitmentConfig; use 
solana_sdk::pubkey::Pubkey; use solana_sdk::signer::keypair::Keypair; @@ -133,6 +136,17 @@ impl RpcChannel { } } +/// Governance information obtained from the contract state. +/// Updated whenever the contract transitions to a new Running state. +#[derive(Debug, Clone)] +pub struct GovernanceInfo { + pub me: Participant, + pub threshold: usize, + pub epoch: u64, + pub public_key: PublicKey, + pub participants: BTreeSet, +} + #[derive(Clone)] pub struct ContractStateWatcher { account_id: AccountId, @@ -187,7 +201,7 @@ impl ContractStateWatcher { ) } - pub fn account_id(&self) -> &AccountId { + pub fn my_account_id(&self) -> &AccountId { &self.account_id } @@ -307,6 +321,44 @@ impl ContractStateWatcher { } } + /// Waits for the contract to be in Running state and extracts governance info. + /// Keeps waiting if we are not a participant in the current running state. + pub async fn wait_governance(&mut self) -> GovernanceInfo { + loop { + let running = self.wait_running().await; + if let Some(me) = running.participants.find_participant(&self.account_id) { + return GovernanceInfo { + me: *me, + threshold: running.threshold, + epoch: running.epoch, + public_key: PublicKey::from(running.public_key), + participants: running.participants.keys().copied().collect(), + }; + } + // We're not a participant, wait for next state change + tracing::warn!("wait_governance: we are not a participant, waiting for state change"); + let _ = self.contract_state.changed().await; + } + } + + /// Try to get governance info from current state without waiting. + /// Returns None if not in Running state or if we're not a participant. + pub fn governance(&self) -> Option { + match self.borrow_state().as_ref()? { + ProtocolState::Running(running) => { + let me = running.participants.find_participant(&self.account_id)?; + Some(GovernanceInfo { + me: *me, + threshold: running.threshold, + epoch: running.epoch, + public_key: PublicKey::from(running.public_key), + participants: running.participants.keys().copied().collect(), + }) + } + _ => None, + } + } + /// Create a list of contract states that share a single channel but use different account ids. #[cfg(feature = "test-feature")] pub fn test_batch( diff --git a/integration-tests/src/actions/sign.rs b/integration-tests/src/actions/sign.rs index d667e50f..820f0360 100644 --- a/integration-tests/src/actions/sign.rs +++ b/integration-tests/src/actions/sign.rs @@ -852,7 +852,13 @@ impl SignAction<'_> { .await?; let err = wait_for::rogue_message_responded(rogue_status).await?; - assert!(err.contains(&errors::RespondError::InvalidSignature.to_string())); + // The rogue respond can race with the honest one; if the honest response lands first + // the contract returns RequestNotFound instead of InvalidSignature. 
+ assert!( + err.contains(&errors::RespondError::InvalidSignature.to_string()) + || err.contains(&errors::InvalidParameters::RequestNotFound.to_string()), + "unexpected rogue respond error: {err}" + ); Some(rogue) } else { None diff --git a/integration-tests/src/mpc_fixture/builder.rs b/integration-tests/src/mpc_fixture/builder.rs index 63799c84..6ffb12e4 100644 --- a/integration-tests/src/mpc_fixture/builder.rs +++ b/integration-tests/src/mpc_fixture/builder.rs @@ -27,10 +27,8 @@ use mpc_node::rpc::RpcChannel; use mpc_node::storage::{secret_storage, triple_storage::TriplePair, Options}; use near_sdk::AccountId; use std::collections::HashMap; -use std::sync::Arc; use tokio::sync::mpsc::{self, Sender}; use tokio::sync::watch; -use tokio::sync::RwLock; pub struct MpcFixtureBuilder { prepared_nodes: Vec, @@ -425,7 +423,6 @@ impl MpcFixtureNodeBuilder { let (config_tx, config_rx) = watch::channel(self.config); let channels = protocol::test_setup::TestProtocolChannels { - sign_rx: Arc::new(RwLock::new(sign_rx)), msg_channel: self.messaging.channel.clone(), rpc_channel, config: config_rx.clone(), @@ -446,6 +443,7 @@ impl MpcFixtureNodeBuilder { storage, channels, context.contract_state.clone(), + sign_rx, ) .await; diff --git a/integration-tests/src/mpc_fixture/fixture_interface.rs b/integration-tests/src/mpc_fixture/fixture_interface.rs index 8b78fd43..f67d9c7e 100644 --- a/integration-tests/src/mpc_fixture/fixture_interface.rs +++ b/integration-tests/src/mpc_fixture/fixture_interface.rs @@ -6,7 +6,8 @@ use cait_sith::protocol::Participant; use mpc_node::backlog::Backlog; use mpc_node::config::Config; use mpc_node::mesh::MeshState; -use mpc_node::protocol::state::NodeStateWatcher; +use mpc_node::protocol::contract::ResharingContractState; +use mpc_node::protocol::state::{NodeStateWatcher, NodeStatus}; use mpc_node::protocol::sync::SyncChannel; use mpc_node::protocol::{MessageChannel, ProtocolState, Sign}; use mpc_node::storage::{PresignatureStorage, TripleStorage}; @@ -75,6 +76,65 @@ impl MpcFixture { } } + /// Wait for all nodes to reach the Running state. + pub async fn wait_for_running(&self) { + for node in &self.nodes { + node.wait_for_running().await; + } + } + + /// Trigger resharing by transitioning the contract state to Resharing. + /// The participants remain the same (no node joins or leaves). + pub fn trigger_resharing(&self) { + let current_state = self.shared_contract_state.borrow().clone(); + if let Some(ProtocolState::Running(running)) = current_state { + let resharing_state = ResharingContractState { + old_epoch: running.epoch, + old_participants: running.participants.clone(), + new_participants: running.participants.clone(), + threshold: running.threshold, + public_key: running.public_key, + finished_votes: HashSet::new(), + cancel_votes: HashSet::new(), + }; + let _ = self + .shared_contract_state + .send(Some(ProtocolState::Resharing(resharing_state))); + tracing::info!("triggered resharing"); + } else { + tracing::warn!("cannot trigger resharing: contract not in Running state"); + } + } + + /// Complete resharing by transitioning the contract state back to Running. + /// Note: For testing purposes, we keep the same epoch since we're not actually + /// running the resharing protocol. This simulates a resharing that was cancelled + /// or where the participants didn't change. 
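+ ///
+ /// Typical use in a test (see `test_sign_task_survives_resharing`):
+ ///
+ /// ```ignore
+ /// network.trigger_resharing();
+ /// // ... give the nodes a moment to notice and pause ...
+ /// network.complete_resharing();
+ /// ```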
+ pub fn complete_resharing(&self) { + let current_state = self.shared_contract_state.borrow().clone(); + if let Some(ProtocolState::Resharing(resharing)) = current_state { + let running_state = mpc_node::protocol::contract::RunningContractState { + // Keep the same epoch since we're not actually resharing + epoch: resharing.old_epoch, + participants: resharing.new_participants.clone(), + threshold: resharing.threshold, + public_key: resharing.public_key, + candidates: Default::default(), + join_votes: Default::default(), + leave_votes: Default::default(), + }; + let _ = self + .shared_contract_state + .send(Some(ProtocolState::Running(running_state))); + tracing::info!( + epoch = resharing.old_epoch, + "completed resharing (same epoch)" + ); + } else { + tracing::warn!("cannot complete resharing: contract not in Resharing state"); + } + } + /// Print all messages to debug. pub async fn print_msg_log(&self) { let guard = self.output.msg_log.lock().await; @@ -85,6 +145,16 @@ impl MpcFixture { } impl MpcFixtureNode { + /// Wait for this node to reach the Running state. + pub async fn wait_for_running(&self) { + loop { + if matches!(self.state.status(), NodeStatus::Running { .. }) { + return; + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + } + pub async fn wait_for_triples(&self, threshold_per_node: usize) { loop { let count = self.triple_storage.len_by_owner(self.me).await; diff --git a/integration-tests/tests/cases/mpc.rs b/integration-tests/tests/cases/mpc.rs index 2250d207..128135bb 100644 --- a/integration-tests/tests/cases/mpc.rs +++ b/integration-tests/tests/cases/mpc.rs @@ -199,6 +199,147 @@ async fn test_basic_sign() { ); } +/// Test that sign tasks survive resharing and complete after resharing finishes. +/// This test: +/// 1. Waits for all nodes to be running +/// 2. Starts a signing request +/// 3. Triggers resharing mid-signature +/// 4. Completes resharing +/// 5. 
Verifies the signature still completes +#[test(tokio::test(flavor = "multi_thread"))] +async fn test_sign_task_survives_resharing() { + let network = MpcFixtureBuilder::default() + .only_generate_signatures() + .build() + .await; + + // Wait for all nodes to reach Running state + tokio::time::timeout(Duration::from_secs(5), network.wait_for_running()) + .await + .expect("nodes should reach Running state"); + + tokio::time::timeout( + Duration::from_millis(300), + network.wait_for_presignatures(2), + ) + .await + .expect("should start with enough presignatures"); + + tracing::info!("sending sign request"); + let request = sign_request(0); + network[0] + .sign_tx + .send(Sign::Request(request.clone())) + .await + .unwrap(); + network[1] + .sign_tx + .send(Sign::Request(request.clone())) + .await + .unwrap(); + network[2] + .sign_tx + .send(Sign::Request(request.clone())) + .await + .unwrap(); + + // Give some time for the sign task to start organizing + tokio::time::sleep(Duration::from_millis(50)).await; + + // Trigger resharing while sign task is in progress + tracing::info!("triggering resharing"); + network.trigger_resharing(); + + // Give some time for nodes to notice resharing and pause + tokio::time::sleep(Duration::from_millis(200)).await; + + // Complete resharing + tracing::info!("completing resharing"); + network.complete_resharing(); + + // The signature should still complete after resharing + let timeout = Duration::from_secs(15); + let actions = tokio::time::timeout(timeout, network.wait_for_actions(1)) + .await + .expect("signature should complete after resharing"); + + assert_eq!(actions.len(), 1); + let action_str = actions.iter().next().unwrap(); + assert!( + action_str.contains("RpcAction::Publish"), + "unexpected rpc action {action_str}" + ); + tracing::info!("signature completed successfully after resharing"); +} + +/// Test that new sign requests during resharing are queued and complete after resharing. 
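+/// This test:
+/// 1. Waits for all nodes to be running
+/// 2. Triggers resharing
+/// 3. Sends a signing request while resharing is in progress
+/// 4. Completes resharing
+/// 5. Verifies the signature still completes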
+#[test(tokio::test(flavor = "multi_thread"))] +async fn test_sign_request_during_resharing() { + let network = MpcFixtureBuilder::default() + .only_generate_signatures() + .build() + .await; + + // Wait for all nodes to reach Running state + tokio::time::timeout(Duration::from_secs(5), network.wait_for_running()) + .await + .expect("nodes should reach Running state"); + + tokio::time::timeout( + Duration::from_millis(300), + network.wait_for_presignatures(2), + ) + .await + .expect("should start with enough presignatures"); + + // Trigger resharing first + tracing::info!("triggering resharing before sign request"); + network.trigger_resharing(); + + // Give some time for nodes to notice resharing + tokio::time::sleep(Duration::from_millis(100)).await; + + // Send a sign request while in resharing + tracing::info!("sending sign request during resharing"); + let request = sign_request(1); + network[0] + .sign_tx + .send(Sign::Request(request.clone())) + .await + .unwrap(); + network[1] + .sign_tx + .send(Sign::Request(request.clone())) + .await + .unwrap(); + network[2] + .sign_tx + .send(Sign::Request(request.clone())) + .await + .unwrap(); + + // Give some time for the request to be queued + tokio::time::sleep(Duration::from_millis(100)).await; + + // Complete resharing + tracing::info!("completing resharing"); + network.complete_resharing(); + + // The signature should complete after resharing + let timeout = Duration::from_secs(15); + let actions = tokio::time::timeout(timeout, network.wait_for_actions(1)) + .await + .expect("signature should complete after resharing"); + + assert_eq!(actions.len(), 1); + let action_str = actions.iter().next().unwrap(); + assert!( + action_str.contains("RpcAction::Publish"), + "unexpected rpc action {action_str}" + ); + tracing::info!("signature request during resharing completed successfully"); +} + fn sign_request(seed: u8) -> IndexedSignRequest { IndexedSignRequest { id: SignId::new([seed; 32]), diff --git a/integration-tests/tests/cases/solana.rs b/integration-tests/tests/cases/solana.rs index b1b0f6c6..6c85d91f 100644 --- a/integration-tests/tests/cases/solana.rs +++ b/integration-tests/tests/cases/solana.rs @@ -40,3 +40,45 @@ async fn test_solana_signature_basic() -> anyhow::Result<()> { anyhow::bail!("signature verification failed"); } } + +// Concurrent variant: spawn many sign requests at once against a very small +// presignature stockpile and assert we make forward progress (no livelock). +#[test(tokio::test)] +async fn test_solana_stockpile_depletion() -> anyhow::Result<()> { + let cluster = cluster::spawn() + .solana() + .with_config(|conf| { + // tiny presignature stock so contention is probable + conf.protocol.presignature.min_presignatures = 2; + conf.protocol.presignature.max_presignatures = 4; + }) + .await?; + + // spawn many concurrent requests and assert each completes within a + // reasonable timeout. This catches stuck/livelock cases deterministically + // and fast because the concurrent pressure increases contention. 
+ let concurrent = 30; + let mut futs = Vec::with_capacity(concurrent); + for _ in 0..concurrent { + // keep the futures local instead of spawning tasks, so borrows of `cluster` stay simple + futs.push(tokio::time::timeout( + std::time::Duration::from_secs(15), + async { + cluster + .sign() + .solana() + .await + .map_err(|e| anyhow::anyhow!(e)) + }, + )); + } + + // wait for all futures to complete and ensure none timed out or errored + for outcome in futures::future::join_all(futs).await { + match outcome { + Err(err) => anyhow::bail!("timeout waiting for sign request to complete: {err:?}"), + Ok(Err(err)) => anyhow::bail!("sign request failed: {err:?}"), + Ok(Ok(_)) => {} + } + } + + Ok(()) +}
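A possible follow-up sketch (not part of this change): the same pressure test can bound the number of in-flight requests with `buffer_unordered` from the `futures` crate that `join_all` already comes from. `cluster`, `sign()` and `solana()` are the fixture calls used above; everything else is illustrative.

    use futures::{stream, StreamExt};

    // Run the requests with at most 8 in flight at a time, each under its own timeout.
    let results: Vec<_> = stream::iter(0..concurrent)
        .map(|_| {
            tokio::time::timeout(
                std::time::Duration::from_secs(15),
                async { cluster.sign().solana().await.map_err(|e| anyhow::anyhow!(e)) },
            )
        })
        .buffer_unordered(8)
        .collect()
        .await;

The per-outcome checks then stay the same as in the loop above.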