From ce73e21a449955bb85597045d149e13826066c98 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 16:32:06 +0000 Subject: [PATCH 01/14] feat: optimize address batch pipeline --- forester-utils/src/address_staging_tree.rs | 17 +- forester/src/processor/v2/helpers.rs | 63 ++++-- forester/src/processor/v2/strategy/address.rs | 79 +++---- program-tests/utils/src/e2e_test_env.rs | 103 ++++----- .../utils/src/mock_batched_forester.rs | 23 +- .../utils/src/test_batch_forester.rs | 71 +++--- prover/client/src/errors.rs | 3 + prover/client/src/helpers.rs | 10 +- .../batch_address_append/proof_inputs.rs | 84 ++++--- .../proof_types/batch_append/proof_inputs.rs | 7 +- .../proof_types/batch_update/proof_inputs.rs | 2 +- prover/client/tests/batch_address_append.rs | 23 +- sdk-libs/client/src/indexer/types/queue.rs | 208 ++++++++++++++++-- .../program-test/src/indexer/test_indexer.rs | 20 +- sparse-merkle-tree/src/indexed_changelog.rs | 6 +- sparse-merkle-tree/tests/indexed_changelog.rs | 14 +- 16 files changed, 472 insertions(+), 261 deletions(-) diff --git a/forester-utils/src/address_staging_tree.rs b/forester-utils/src/address_staging_tree.rs index 786ddb6ac0..a6b1aa89bd 100644 --- a/forester-utils/src/address_staging_tree.rs +++ b/forester-utils/src/address_staging_tree.rs @@ -121,7 +121,7 @@ impl AddressStagingTree { low_element_next_values: &[[u8; 32]], low_element_indices: &[u64], low_element_next_indices: &[u64], - low_element_proofs: &[Vec<[u8; 32]>], + low_element_proofs: &[[[u8; 32]; HEIGHT]], leaves_hashchain: [u8; 32], zkp_batch_size: usize, epoch: u64, @@ -145,15 +145,12 @@ impl AddressStagingTree { let inputs = get_batch_address_append_circuit_inputs::( next_index, old_root, - low_element_values.to_vec(), - low_element_next_values.to_vec(), - low_element_indices.iter().map(|v| *v as usize).collect(), - low_element_next_indices - .iter() - .map(|v| *v as usize) - .collect(), - low_element_proofs.to_vec(), - addresses.to_vec(), + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + addresses, &mut self.sparse_tree, leaves_hashchain, zkp_batch_size, diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index ed135cb6a4..a0f3e3bb5b 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -9,6 +9,7 @@ use light_client::{ indexer::{AddressQueueData, Indexer, QueueElementsV2Options, StateQueueData}, rpc::Rpc, }; +use light_hasher::hash_chain::create_hash_chain_from_slice; use crate::processor::v2::{common::clamp_to_u16, BatchContext}; @@ -22,6 +23,17 @@ pub(crate) fn lock_recover<'a, T>(mutex: &'a Mutex, name: &'static str) -> Mu } } +#[derive(Debug, Clone)] +pub struct AddressBatchSnapshot { + pub addresses: Vec<[u8; 32]>, + pub low_element_values: Vec<[u8; 32]>, + pub low_element_next_values: Vec<[u8; 32]>, + pub low_element_indices: Vec, + pub low_element_next_indices: Vec, + pub low_element_proofs: Vec<[[u8; 32]; HEIGHT]>, + pub leaves_hashchain: [u8; 32], +} + pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { let rpc = context.rpc_pool.get_connection().await?; let mut account = rpc @@ -474,20 +486,52 @@ impl StreamingAddressQueue { } } - pub fn get_batch_data(&self, start: usize, end: usize) -> Option { + pub fn get_batch_snapshot( + &self, + start: usize, + end: usize, + hashchain_idx: usize, + ) -> crate::Result>> { let available = self.wait_for_batch(end); if start >= available { - return None; + return Ok(None); } let actual_end = end.min(available); let data = lock_recover(&self.data, "streaming_address_queue.data"); - Some(BatchDataSlice { - addresses: data.addresses[start..actual_end].to_vec(), + + let addresses = data.addresses[start..actual_end].to_vec(); + if addresses.is_empty() { + return Err(anyhow!("Empty batch at start={}", start)); + } + + let leaves_hashchain = match data.leaves_hash_chains.get(hashchain_idx).copied() { + Some(hashchain) => hashchain, + None => { + tracing::debug!( + "Missing leaves_hash_chain for batch {} (available: {}), deriving from addresses", + hashchain_idx, + data.leaves_hash_chains.len() + ); + create_hash_chain_from_slice(&addresses).map_err(|error| { + anyhow!( + "Failed to derive leaves_hash_chain for batch {} from {} addresses: {}", + hashchain_idx, + addresses.len(), + error + ) + })? + } + }; + + Ok(Some(AddressBatchSnapshot { low_element_values: data.low_element_values[start..actual_end].to_vec(), low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), low_element_indices: data.low_element_indices[start..actual_end].to_vec(), low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - }) + low_element_proofs: data.reconstruct_proofs::(start..actual_end)?, + addresses, + leaves_hashchain, + })) } pub fn into_data(self) -> AddressQueueData { @@ -553,15 +597,6 @@ impl StreamingAddressQueue { } } -#[derive(Debug, Clone)] -pub struct BatchDataSlice { - pub addresses: Vec<[u8; 32]>, - pub low_element_values: Vec<[u8; 32]>, - pub low_element_next_values: Vec<[u8; 32]>, - pub low_element_indices: Vec, - pub low_element_next_indices: Vec, -} - pub async fn fetch_streaming_address_batches( context: &BatchContext, total_elements: u64, diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs index 06e94d5500..51ab05143a 100644 --- a/forester/src/processor/v2/strategy/address.rs +++ b/forester/src/processor/v2/strategy/address.rs @@ -14,11 +14,10 @@ use tracing::{debug, info, instrument}; use crate::processor::v2::{ batch_job_builder::BatchJobBuilder, - common::get_leaves_hashchain, errors::V2Error, helpers::{ fetch_address_zkp_batch_size, fetch_onchain_address_root, fetch_streaming_address_batches, - lock_recover, StreamingAddressQueue, + AddressBatchSnapshot, StreamingAddressQueue, }, proof_worker::ProofInput, root_guard::{reconcile_alignment, AlignmentDecision}, @@ -267,9 +266,23 @@ impl BatchJobBuilder for AddressQueueData { let batch_end = start + zkp_batch_size_usize; - let batch_data = self - .streaming_queue - .get_batch_data(start, batch_end) + let streaming_queue = &self.streaming_queue; + let staging_tree = &mut self.staging_tree; + let hashchain_idx = start / zkp_batch_size_usize; + let AddressBatchSnapshot { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + leaves_hashchain, + } = streaming_queue + .get_batch_snapshot::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( + start, + batch_end, + hashchain_idx, + )? .ok_or_else(|| { anyhow!( "Batch data not available: start={}, end={}, available={}", @@ -278,31 +291,21 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.available_batches() * zkp_batch_size_usize ) })?; - - let addresses = &batch_data.addresses; let zkp_batch_size_actual = addresses.len(); - - if zkp_batch_size_actual == 0 { - return Err(anyhow!("Empty batch at start={}", start)); - } - - let low_element_values = &batch_data.low_element_values; - let low_element_next_values = &batch_data.low_element_next_values; - let low_element_indices = &batch_data.low_element_indices; - let low_element_next_indices = &batch_data.low_element_next_indices; - - let low_element_proofs: Vec> = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - (start..start + zkp_batch_size_actual) - .map(|i| data.reconstruct_proof(i, DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as u8)) - .collect::, _>>()? - }; - - let hashchain_idx = start / zkp_batch_size_usize; - let leaves_hashchain = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - get_leaves_hashchain(&data.leaves_hash_chains, hashchain_idx)? - }; + let result = staging_tree + .process_batch( + &addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + leaves_hashchain, + zkp_batch_size_actual, + epoch, + tree, + ) + .map_err(|err| map_address_staging_error(tree, err))?; let tree_batch = tree_next_index / zkp_batch_size_usize; let absolute_index = data_start + start; @@ -318,24 +321,6 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.is_complete() ); - let result = self.staging_tree.process_batch( - addresses, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - &low_element_proofs, - leaves_hashchain, - zkp_batch_size_actual, - epoch, - tree, - ); - - let result = match result { - Ok(r) => r, - Err(err) => return Err(map_address_staging_error(tree, err)), - }; - Ok(Some(( ProofInput::AddressAppend(result.circuit_inputs), result.new_root, diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index a16a7925a7..097ae1f9a9 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -73,7 +73,6 @@ use account_compression::{ use anchor_lang::{prelude::AccountMeta, AnchorSerialize, Discriminator}; use create_address_test_program::create_invoke_cpi_instruction; use forester_utils::{ - account_zero_copy::AccountZeroCopy, address_merkle_tree_config::{address_tree_ready_for_rollover, state_tree_ready_for_rollover}, forester_epoch::{Epoch, Forester, TreeAccounts}, utils::airdrop_lamports, @@ -194,6 +193,7 @@ use crate::{ }, test_batch_forester::{perform_batch_append, perform_batch_nullify}, test_forester::{empty_address_queue_test, nullify_compressed_accounts}, + AccountZeroCopy, }; pub struct User { @@ -748,70 +748,67 @@ where .with_address_queue(None, Some(batch.batch_size as u16)); let result = self .indexer - .get_queue_elements(merkle_tree_pubkey.to_bytes(), options, None) + .get_queue_elements( + merkle_tree_pubkey.to_bytes(), + options, + None, + ) .await .unwrap(); - let addresses = result - .value - .address_queue - .map(|aq| aq.addresses) - .unwrap_or_default(); + let address_queue = result.value.address_queue.unwrap(); + let low_element_proofs = address_queue + .reconstruct_all_proofs::<{ + DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize + }>() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.start_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = self - .indexer - .get_multiple_new_address_proofs( - merkle_tree_pubkey.to_bytes(), - addresses.clone(), - None, - ) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices - .push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices - .push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values - .push(non_inclusion_proof.low_address_next_value); - - low_element_proofs - .push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = self.indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); - let mut sparse_merkle_tree = SparseMerkleTree::::new(<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items).unwrap(), start_index); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + subtrees, + .. + } = address_queue; + let mut sparse_merkle_tree = SparseMerkleTree::< + Poseidon, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >::new( + subtrees.as_slice().try_into().unwrap(), + start_index, + ); - let mut changelog: Vec> = Vec::new(); - let mut indexed_changelog: Vec> = Vec::new(); + let mut changelog: Vec< + ChangelogEntry<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>, + > = Vec::new(); + let mut indexed_changelog: Vec< + IndexedChangelogEntry< + usize, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >, + > = Vec::new(); let inputs = get_batch_address_append_circuit_inputs::< { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, >( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, @@ -834,9 +831,13 @@ where if response_result.status().is_success() { let body = response_result.text().await.unwrap(); - let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); + let proof_json = deserialize_gnark_proof_json(&body) + .map_err(|error| RpcError::CustomError(error.to_string())) + .unwrap(); + let (proof_a, proof_b, proof_c) = + proof_from_json_struct(proof_json); + let (proof_a, proof_b, proof_c) = + compress_proof(&proof_a, &proof_b, &proof_c); let instruction_data = InstructionDataBatchNullifyInputs { new_root: circuit_inputs_new_root, compressed_proof: CompressedProof { diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 4458aa03b3..2a93f772ba 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -260,7 +260,7 @@ impl MockBatchedAddressForester { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; HEIGHT]> = Vec::new(); for new_element_value in &new_element_values { let non_inclusion_proof = self .merkle_tree @@ -270,7 +270,14 @@ impl MockBatchedAddressForester { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + let proof = non_inclusion_proof.merkle_proof.as_slice().try_into().map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid low element proof length: expected {}, got {}", + HEIGHT, + non_inclusion_proof.merkle_proof.len() + )) + })?; + low_element_proofs.push(proof); } let subtrees = self.merkle_tree.merkle_tree.get_subtrees(); let mut merkle_tree = match <[[u8; 32]; HEIGHT]>::try_from(subtrees) { @@ -287,12 +294,12 @@ impl MockBatchedAddressForester { let inputs = match get_batch_address_append_circuit_inputs::( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values.clone(), + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut merkle_tree, leaves_hashchain, zkp_batch_size as usize, diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index 8e6909704f..0952a7fd76 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -165,7 +165,9 @@ pub async fn create_append_batch_ix_data( bundle.merkle_tree.root() ); let proof_client = ProofClient::local(); - let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); + let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs) + .to_string() + ; match proof_client.generate_proof(inputs_json).await { Ok(compressed_proof) => ( @@ -319,13 +321,13 @@ pub async fn get_batched_nullify_ix_data( }) } -use forester_utils::{ - account_zero_copy::AccountZeroCopy, instructions::create_account::create_account_instruction, -}; +use forester_utils::instructions::create_account::create_account_instruction; use light_client::indexer::{Indexer, QueueElementsV2Options}; use light_program_test::indexer::state_tree::StateMerkleTreeBundle; use light_sparse_merkle_tree::SparseMerkleTree; +use crate::AccountZeroCopy; + pub async fn assert_registry_created_batched_state_merkle_tree( rpc: &mut R, payer_pubkey: Pubkey, @@ -663,50 +665,33 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_slice(addresses.as_slice()).unwrap(); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.start_index as usize; assert!( start_index >= 1, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = indexer - .get_multiple_new_address_proofs(merkle_tree_pubkey.to_bytes(), addresses.clone(), None) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices.push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices.push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values.push(non_inclusion_proof.low_address_next_value); - - low_element_proofs.push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + subtrees, + .. + } = address_queue; let mut sparse_merkle_tree = SparseMerkleTree::< Poseidon, { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, - >::new( - <[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items) - .unwrap(), - start_index, - ); + >::new(subtrees.as_slice().try_into().unwrap(), start_index); let mut changelog: Vec> = Vec::new(); @@ -718,12 +703,12 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, diff --git a/prover/client/src/errors.rs b/prover/client/src/errors.rs index 85c1bc8fbe..e095bf3579 100644 --- a/prover/client/src/errors.rs +++ b/prover/client/src/errors.rs @@ -37,6 +37,9 @@ pub enum ProverClientError { #[error("Invalid proof data: {0}")] InvalidProofData(String), + #[error("Integer conversion failed: {0}")] + IntegerConversion(String), + #[error("Hashchain mismatch: computed {computed:?} != expected {expected:?} (batch_size={batch_size}, next_index={next_index})")] HashchainMismatch { computed: [u8; 32], diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 6ea223e79f..98457479e2 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -6,6 +6,8 @@ use num_bigint::{BigInt, BigUint}; use num_traits::{Num, ToPrimitive}; use serde::Serialize; +use crate::errors::ProverClientError; + pub fn get_project_root() -> Option { let output = Command::new("git") .args(["rev-parse", "--show-toplevel"]) @@ -48,7 +50,7 @@ pub fn compute_root_from_merkle_proof( leaf: [u8; 32], path_elements: &[[u8; 32]; HEIGHT], path_index: u32, -) -> ([u8; 32], ChangelogEntry) { +) -> Result<([u8; 32], ChangelogEntry), ProverClientError> { let mut changelog_entry = ChangelogEntry::default_with_index(path_index as usize); let mut current_hash = leaf; @@ -56,14 +58,14 @@ pub fn compute_root_from_merkle_proof( for (level, path_element) in path_elements.iter().enumerate() { changelog_entry.path[level] = Some(current_hash); if current_index.is_multiple_of(2) { - current_hash = Poseidon::hashv(&[¤t_hash, path_element]).unwrap(); + current_hash = Poseidon::hashv(&[¤t_hash, path_element])?; } else { - current_hash = Poseidon::hashv(&[path_element, ¤t_hash]).unwrap(); + current_hash = Poseidon::hashv(&[path_element, ¤t_hash])?; } current_index /= 2; } - (current_hash, changelog_entry) + Ok((current_hash, changelog_entry)) } pub fn big_uint_to_string(big_uint: &BigUint) -> String { diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index f80e8d49e4..32408fdc02 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Debug}; use light_hasher::{ bigint::bigint_to_be_bytes_array, @@ -187,21 +187,28 @@ impl BatchAddressAppendInputs { pub fn get_batch_address_append_circuit_inputs( next_index: usize, current_root: [u8; 32], - low_element_values: Vec<[u8; 32]>, - low_element_next_values: Vec<[u8; 32]>, - low_element_indices: Vec, - low_element_next_indices: Vec, - low_element_proofs: Vec>, - new_element_values: Vec<[u8; 32]>, + low_element_values: &[[u8; 32]], + low_element_next_values: &[[u8; 32]], + low_element_indices: &[impl Copy + TryInto + Debug], + low_element_next_indices: &[impl Copy + TryInto + Debug], + low_element_proofs: &[[[u8; 32]; HEIGHT]], + new_element_values: &[[u8; 32]], sparse_merkle_tree: &mut SparseMerkleTree, leaves_hashchain: [u8; 32], zkp_batch_size: usize, changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - let new_element_values = new_element_values[0..zkp_batch_size].to_vec(); - - let computed_hashchain = create_hash_chain_from_slice(&new_element_values).map_err(|e| { + let new_element_values = &new_element_values[..zkp_batch_size]; + let mut new_root = [0u8; 32]; + let mut low_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); + let mut new_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_next_values = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_next_indices = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_values = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_indices = Vec::with_capacity(new_element_values.len()); + + let computed_hashchain = create_hash_chain_from_slice(new_element_values).map_err(|e| { ProverClientError::GenericError(format!("Failed to compute hashchain: {}", e)) })?; if computed_hashchain != leaves_hashchain { @@ -229,15 +236,6 @@ pub fn get_batch_address_append_circuit_inputs( next_index ); - let mut new_root = [0u8; 32]; - let mut low_element_circuit_merkle_proofs = vec![]; - let mut new_element_circuit_merkle_proofs = vec![]; - - let mut patched_low_element_next_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_next_indices: Vec = Vec::new(); - let mut patched_low_element_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_indices: Vec = Vec::new(); - let mut patcher = ChangelogProofPatcher::new::(changelog); let is_first_batch = indexed_changelog.is_empty(); @@ -245,21 +243,33 @@ pub fn get_batch_address_append_circuit_inputs( for i in 0..new_element_values.len() { let mut changelog_index = 0; + let low_element_index = low_element_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element index {:?} does not fit into usize", + low_element_indices[i] + )) + })?; + let low_element_next_index = low_element_next_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element next index {:?} does not fit into usize", + low_element_next_indices[i] + )) + })?; let new_element_index = next_index + i; let mut low_element: IndexedElement = IndexedElement { - index: low_element_indices[i], + index: low_element_index, value: BigUint::from_bytes_be(&low_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; let mut new_element: IndexedElement = IndexedElement { index: new_element_index, value: BigUint::from_bytes_be(&new_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof = low_element_proofs[i]; let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); patch_indexed_changelogs( 0, @@ -293,18 +303,10 @@ pub fn get_batch_address_append_circuit_inputs( next_value: bigint_to_be_bytes_array::<32>(&new_element.value)?, index: new_low_element.index, }; + let low_element_changelog_proof = low_element_proof; let intermediate_root = { - let mut low_element_proof_arr: [[u8; 32]; HEIGHT] = low_element_proof - .clone() - .try_into() - .map_err(|v: Vec<[u8; 32]>| { - ProverClientError::ProofPatchFailed(format!( - "low element proof length mismatch: expected {}, got {}", - HEIGHT, - v.len() - )) - })?; + let mut low_element_proof_arr = low_element_changelog_proof; patcher.update_proof::(low_element.index(), &mut low_element_proof_arr); let merkle_proof = low_element_proof_arr; @@ -321,7 +323,7 @@ pub fn get_batch_address_append_circuit_inputs( old_low_leaf_hash, &merkle_proof, low_element.index as u32, - ); + )?; if computed_root != expected_root_for_low { let low_value_bytes = bigint_to_be_bytes_array::<32>(&low_element.value) .map_err(|e| { @@ -362,7 +364,7 @@ pub fn get_batch_address_append_circuit_inputs( new_low_leaf_hash, &merkle_proof, new_low_element.index as u32, - ); + )?; patcher.push_changelog_entry::(changelog, changelog_entry); low_element_circuit_merkle_proofs.push( @@ -376,13 +378,7 @@ pub fn get_batch_address_append_circuit_inputs( }; let low_element_changelog_entry = IndexedChangelogEntry { element: new_low_element_raw, - proof: low_element_proof.as_slice()[..HEIGHT] - .try_into() - .map_err(|_| { - ProverClientError::ProofPatchFailed( - "low_element_proof slice conversion failed".to_string(), - ) - })?, + proof: low_element_changelog_proof, changelog_index: indexed_changelog.len(), //change_log_index, }; @@ -409,7 +405,7 @@ pub fn get_batch_address_append_circuit_inputs( new_element_leaf_hash, &merkle_proof_array, current_index as u32, - ); + )?; if i == 0 && changelog.len() == 1 { if sparse_next_idx_before != current_index { @@ -436,7 +432,7 @@ pub fn get_batch_address_append_circuit_inputs( zero_hash, &merkle_proof_array, current_index as u32, - ); + )?; if root_with_zero != intermediate_root { tracing::error!( "ELEMENT {} NEW_PROOF MISMATCH: proof + ZERO = {:?}[..4] but expected \ diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index ef0327ac1d..7dd578e599 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -187,8 +187,11 @@ pub fn get_batch_append_inputs( }; // Update the root based on the current proof and nullifier - let (updated_root, changelog_entry) = - compute_root_from_merkle_proof(final_leaf, &merkle_proof_array, start_index + i as u32); + let (updated_root, changelog_entry) = compute_root_from_merkle_proof( + final_leaf, + &merkle_proof_array, + start_index + i as u32, + )?; new_root = updated_root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 2136d01d10..7f8c08e0d1 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -175,7 +175,7 @@ pub fn get_batch_update_inputs( index_bytes[28..].copy_from_slice(&(*index).to_be_bytes()); let nullifier = Poseidon::hashv(&[leaf, &index_bytes, &tx_hashes[i]]).unwrap(); let (root, changelog_entry) = - compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index); + compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index)?; new_root = root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 22f58d5362..ac73c3809e 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -45,7 +45,8 @@ async fn prove_batch_address_append() { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = + Vec::new(); // Generate non-inclusion proofs for each element for new_element_value in &new_element_values { @@ -57,7 +58,13 @@ async fn prove_batch_address_append() { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); } // Convert big integers to byte arrays @@ -87,12 +94,12 @@ async fn prove_batch_address_append() { get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut sparse_merkle_tree, hash_chain, zkp_batch_size, diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index 40e7cc0f6e..de97ca7739 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use super::super::IndexerError; #[derive(Debug, Clone, PartialEq, Default)] @@ -65,12 +67,10 @@ pub struct AddressQueueData { impl AddressQueueData { /// Reconstruct a merkle proof for a given low_element_index from the deduplicated nodes. - /// The tree_height is needed to know how many levels to traverse. - pub fn reconstruct_proof( + pub fn reconstruct_proof( &self, address_idx: usize, - tree_height: u8, - ) -> Result, IndexerError> { + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -81,10 +81,10 @@ impl AddressQueueData { ), } })?; - let mut proof = Vec::with_capacity(tree_height as usize); + let mut proof = [[0u8; 32]; HEIGHT]; let mut pos = leaf_index; - for level in 0..tree_height { + for (level, proof_element) in proof.iter_mut().enumerate() { let sibling_pos = if pos.is_multiple_of(2) { pos + 1 } else { @@ -114,30 +114,212 @@ impl AddressQueueData { self.node_hashes.len(), ), })?; - proof.push(*hash); + *proof_element = *hash; pos /= 2; } Ok(proof) } + /// Reconstruct a contiguous batch of proofs while reusing a single node lookup table. + pub fn reconstruct_proofs( + &self, + address_range: std::ops::Range, + ) -> Result, IndexerError> { + let node_lookup = self.build_node_lookup(); + let mut proofs = Vec::with_capacity(address_range.len()); + + for address_idx in address_range { + proofs.push(self.reconstruct_proof_with_lookup::(address_idx, &node_lookup)?); + } + + Ok(proofs) + } + /// Reconstruct all proofs for all addresses - pub fn reconstruct_all_proofs( + pub fn reconstruct_all_proofs( &self, - tree_height: u8, - ) -> Result>, IndexerError> { - (0..self.addresses.len()) - .map(|i| self.reconstruct_proof(i, tree_height)) + ) -> Result, IndexerError> { + self.reconstruct_proofs::(0..self.addresses.len()) + } + + fn build_node_lookup(&self) -> HashMap { + self.nodes + .iter() + .copied() + .enumerate() + .map(|(idx, node)| (node, idx)) .collect() } + fn reconstruct_proof_with_lookup( + &self, + address_idx: usize, + node_lookup: &HashMap, + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "address_idx {} out of bounds for low_element_indices (len {})", + address_idx, + self.low_element_indices.len(), + ), + } + })?; + let mut proof = [[0u8; 32]; HEIGHT]; + let mut pos = leaf_index; + + for (level, proof_element) in proof.iter_mut().enumerate() { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let sibling_idx = Self::encode_node_index(level, sibling_pos); + let hash_idx = node_lookup.get(&sibling_idx).copied().ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "Missing proof node at level {} position {} (encoded: {})", + level, sibling_pos, sibling_idx + ), + } + })?; + let hash = + self.node_hashes + .get(hash_idx) + .ok_or_else(|| IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "node_hashes index {} out of bounds (len {})", + hash_idx, + self.node_hashes.len(), + ), + })?; + *proof_element = *hash; + pos /= 2; + } + + Ok(proof) + } + /// Encode node index: (level << 56) | position #[inline] - fn encode_node_index(level: u8, position: u64) -> u64 { + fn encode_node_index(level: usize, position: u64) -> u64 { ((level as u64) << 56) | position } } +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, hint::black_box, time::Instant}; + + use super::AddressQueueData; + + fn hash_from_node(node_index: u64) -> [u8; 32] { + let mut hash = [0u8; 32]; + hash[..8].copy_from_slice(&node_index.to_le_bytes()); + hash[8..16].copy_from_slice(&node_index.rotate_left(17).to_le_bytes()); + hash[16..24].copy_from_slice(&node_index.rotate_right(9).to_le_bytes()); + hash[24..32].copy_from_slice(&(node_index ^ 0xA5A5_A5A5_A5A5_A5A5).to_le_bytes()); + hash + } + + fn build_queue_data(num_addresses: usize) -> AddressQueueData { + let low_element_indices = (0..num_addresses) + .map(|i| (i as u64).saturating_mul(2)) + .collect::>(); + let mut nodes = BTreeMap::new(); + + for &leaf_index in &low_element_indices { + let mut pos = leaf_index; + for level in 0..HEIGHT { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let node_index = ((level as u64) << 56) | sibling_pos; + nodes + .entry(node_index) + .or_insert_with(|| hash_from_node(node_index)); + pos /= 2; + } + } + + let (nodes, node_hashes): (Vec<_>, Vec<_>) = nodes.into_iter().unzip(); + + AddressQueueData { + addresses: vec![[0u8; 32]; num_addresses], + low_element_values: vec![[1u8; 32]; num_addresses], + low_element_next_values: vec![[2u8; 32]; num_addresses], + low_element_indices, + low_element_next_indices: (0..num_addresses).map(|i| (i as u64) + 1).collect(), + nodes, + node_hashes, + initial_root: [9u8; 32], + leaves_hash_chains: vec![[3u8; 32]; num_addresses.max(1)], + subtrees: vec![[4u8; 32]; HEIGHT], + start_index: 0, + root_seq: 0, + } + } + + #[test] + fn batched_reconstruction_matches_individual_reconstruction() { + let queue = build_queue_data::<40>(128); + + let expected = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::<40>(i).unwrap()) + .collect::>(); + let actual = queue + .reconstruct_proofs::<40>(0..queue.addresses.len()) + .unwrap(); + + assert_eq!(actual, expected); + } + + #[test] + #[ignore = "profiling helper"] + fn profile_reconstruct_proofs_batch() { + const HEIGHT: usize = 40; + const NUM_ADDRESSES: usize = 2_048; + const ITERS: usize = 25; + + let queue = build_queue_data::(NUM_ADDRESSES); + + let baseline_start = Instant::now(); + for _ in 0..ITERS { + let proofs = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::(i).unwrap()) + .collect::>(); + black_box(proofs); + } + let baseline = baseline_start.elapsed(); + + let batched_start = Instant::now(); + for _ in 0..ITERS { + black_box( + queue + .reconstruct_proofs::(0..queue.addresses.len()) + .unwrap(), + ); + } + let batched = batched_start.elapsed(); + + println!( + "queue reconstruction profile: addresses={}, height={}, iters={}, individual={:?}, batched={:?}, speedup={:.2}x", + NUM_ADDRESSES, + HEIGHT, + ITERS, + baseline, + batched, + baseline.as_secs_f64() / batched.as_secs_f64(), + ); + } +} + /// V2 Queue Elements Result with deduplicated node data #[derive(Debug, Clone, PartialEq, Default)] pub struct QueueElementsResult { diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 0b5b0583a3..cbaea17320 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -2170,18 +2170,20 @@ impl TestIndexer { let inclusion_proof_inputs = InclusionProofInputs::new(inclusion_proofs.as_slice()).unwrap(); ( - Some(BatchInclusionJsonStruct::from_inclusion_proof_inputs( - &inclusion_proof_inputs, - )), + Some( + BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inclusion_proof_inputs), + ), None, ) } else if height == STATE_MERKLE_TREE_HEIGHT as usize { let inclusion_proof_inputs = InclusionProofInputsLegacy(inclusion_proofs.as_slice()); ( None, - Some(BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( - &inclusion_proof_inputs, - )), + Some( + BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( + &inclusion_proof_inputs, + ), + ), ) } else { return Err(IndexerError::CustomError( @@ -2358,7 +2360,11 @@ impl TestIndexer { if let Some(payload) = payload { (indices, Vec::new(), payload.to_string()) } else { - (indices, Vec::new(), payload_legacy.unwrap().to_string()) + ( + indices, + Vec::new(), + payload_legacy.unwrap().to_string(), + ) } } (None, Some(addresses)) => { diff --git a/sparse-merkle-tree/src/indexed_changelog.rs b/sparse-merkle-tree/src/indexed_changelog.rs index bbd30e1ee6..7e6a26cff7 100644 --- a/sparse-merkle-tree/src/indexed_changelog.rs +++ b/sparse-merkle-tree/src/indexed_changelog.rs @@ -29,7 +29,7 @@ pub fn patch_indexed_changelogs( low_element: &mut IndexedElement, new_element: &mut IndexedElement, low_element_next_value: &mut BigUint, - low_leaf_proof: &mut Vec<[u8; 32]>, + low_leaf_proof: &mut [[u8; 32]; HEIGHT], ) -> Result<(), SparseMerkleTreeError> { // Tests are in program-tests/merkle-tree/tests/indexed_changelog.rs let next_indexed_changelog_indices: Vec = (*indexed_changelogs) @@ -69,7 +69,7 @@ pub fn patch_indexed_changelogs( // Patch the next value. *low_element_next_value = BigUint::from_bytes_be(&changelog_entry.element.next_value); // Patch the proof. - *low_leaf_proof = changelog_entry.proof.to_vec(); + *low_leaf_proof = changelog_entry.proof; } // If we found a new low element. @@ -82,7 +82,7 @@ pub fn patch_indexed_changelogs( next_index: new_low_element_changelog_entry.element.next_index, }; - *low_leaf_proof = new_low_element_changelog_entry.proof.to_vec(); + *low_leaf_proof = new_low_element_changelog_entry.proof; new_element.next_index = low_element.next_index; if new_low_element_changelog_index == indexed_changelogs.len() - 1 { return Ok(()); diff --git a/sparse-merkle-tree/tests/indexed_changelog.rs b/sparse-merkle-tree/tests/indexed_changelog.rs index 7d37142b46..59efda6fde 100644 --- a/sparse-merkle-tree/tests/indexed_changelog.rs +++ b/sparse-merkle-tree/tests/indexed_changelog.rs @@ -92,7 +92,8 @@ fn test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -114,7 +115,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -124,7 +125,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); println!("patched -------------------"); @@ -206,7 +207,8 @@ fn debug_test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -228,7 +230,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -238,7 +240,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); man_indexed_array.elements[low_element.index()] = low_element.clone(); From d0e469796f4dbd8e22703f5636d6dfa0c46b3e94 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 19:30:35 +0000 Subject: [PATCH 02/14] format --- .../utils/src/mock_batched_forester.rs | 18 ++++++++++------- .../utils/src/test_batch_forester.rs | 4 +--- .../program-test/src/indexer/test_indexer.rs | 20 +++++++------------ .../tests/integration_tests.rs | 2 +- 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 2a93f772ba..0101b235aa 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -270,13 +270,17 @@ impl MockBatchedAddressForester { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - let proof = non_inclusion_proof.merkle_proof.as_slice().try_into().map_err(|_| { - ProverClientError::InvalidProofData(format!( - "invalid low element proof length: expected {}, got {}", - HEIGHT, - non_inclusion_proof.merkle_proof.len() - )) - })?; + let proof = non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid low element proof length: expected {}, got {}", + HEIGHT, + non_inclusion_proof.merkle_proof.len() + )) + })?; low_element_proofs.push(proof); } let subtrees = self.merkle_tree.merkle_tree.get_subtrees(); diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index 0952a7fd76..a28efa9ca3 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -165,9 +165,7 @@ pub async fn create_append_batch_ix_data( bundle.merkle_tree.root() ); let proof_client = ProofClient::local(); - let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs) - .to_string() - ; + let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); match proof_client.generate_proof(inputs_json).await { Ok(compressed_proof) => ( diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index cbaea17320..0b5b0583a3 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -2170,20 +2170,18 @@ impl TestIndexer { let inclusion_proof_inputs = InclusionProofInputs::new(inclusion_proofs.as_slice()).unwrap(); ( - Some( - BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inclusion_proof_inputs), - ), + Some(BatchInclusionJsonStruct::from_inclusion_proof_inputs( + &inclusion_proof_inputs, + )), None, ) } else if height == STATE_MERKLE_TREE_HEIGHT as usize { let inclusion_proof_inputs = InclusionProofInputsLegacy(inclusion_proofs.as_slice()); ( None, - Some( - BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( - &inclusion_proof_inputs, - ), - ), + Some(BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( + &inclusion_proof_inputs, + )), ) } else { return Err(IndexerError::CustomError( @@ -2360,11 +2358,7 @@ impl TestIndexer { if let Some(payload) = payload { (indices, Vec::new(), payload.to_string()) } else { - ( - indices, - Vec::new(), - payload_legacy.unwrap().to_string(), - ) + (indices, Vec::new(), payload_legacy.unwrap().to_string()) } } (None, Some(addresses)) => { diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs index 9b40b900e5..2c3e82972a 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs @@ -3863,7 +3863,7 @@ async fn test_d9_edge_many_literals() { #[tokio::test] async fn test_d9_edge_mixed() { use csdk_anchor_full_derived_test::d9_seeds::{ - edge_cases::{AB, SEED_123, _UNDERSCORE_CONST}, + edge_cases::{_UNDERSCORE_CONST, AB, SEED_123}, D9EdgeMixedParams, }; From 4650c0ab33faa272a8809976fb8813585eefbf4f Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 19:42:28 +0000 Subject: [PATCH 03/14] feat: stabilize address batch pipeline --- Cargo.lock | 18 +++++++++--------- prover/client/src/constants.rs | 2 +- prover/client/src/prover.rs | 5 ++++- .../program-test/src/indexer/test_indexer.rs | 4 +++- .../tests/integration_tests.rs | 2 +- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9e293fd2b8..f65f0ac59e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4946,9 +4946,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-derive" @@ -11003,30 +11003,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.47" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde_core", + "serde", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.8" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.27" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", diff --git a/prover/client/src/constants.rs b/prover/client/src/constants.rs index 18a5c05a45..151bf87918 100644 --- a/prover/client/src/constants.rs +++ b/prover/client/src/constants.rs @@ -1,4 +1,4 @@ -pub const SERVER_ADDRESS: &str = "http://localhost:3001"; +pub const SERVER_ADDRESS: &str = "http://127.0.0.1:3001"; pub const HEALTH_CHECK: &str = "/health"; pub const PROVE_PATH: &str = "/prove"; diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 3bf1bab785..56ae20d98a 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -51,7 +51,10 @@ pub async fn spawn_prover() { } pub async fn health_check(retries: usize, timeout: usize) -> bool { - let client = reqwest::Client::new(); + let client = match reqwest::Client::builder().no_proxy().build() { + Ok(client) => client, + Err(_) => return false, + }; let mut result = false; for _ in 0..retries { match client diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 0b5b0583a3..c51298b9cd 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -726,7 +726,9 @@ impl Indexer for TestIndexer { initial_root: address_tree_bundle.root(), leaves_hash_chains: Vec::new(), subtrees: address_tree_bundle.get_subtrees(), - start_index: start as u64, + // Consumers use start_index as the sparse tree's next insertion index, + // not the pagination offset used for queue slicing. + start_index: address_tree_bundle.right_most_index() as u64, root_seq: address_tree_bundle.sequence_number(), }) } else { diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs index 2c3e82972a..9b40b900e5 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs @@ -3863,7 +3863,7 @@ async fn test_d9_edge_many_literals() { #[tokio::test] async fn test_d9_edge_mixed() { use csdk_anchor_full_derived_test::d9_seeds::{ - edge_cases::{_UNDERSCORE_CONST, AB, SEED_123}, + edge_cases::{AB, SEED_123, _UNDERSCORE_CONST}, D9EdgeMixedParams, }; From ed44b772e86025fed20cf60372daf0f2e4fad9b3 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 16:32:39 +0000 Subject: [PATCH 04/14] feat: batch cold account loads in light client --- sdk-libs/client/src/indexer/photon_indexer.rs | 21 ++- sdk-libs/client/src/indexer/types/queue.rs | 29 +++- .../client/src/interface/load_accounts.rs | 125 +++++++++++++----- 3 files changed, 126 insertions(+), 49 deletions(-) diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index 26d16ae235..5698719c8f 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -1142,17 +1142,16 @@ impl Indexer for PhotonIndexer { .value .iter() .map(|x| { - let mut proof_vec = x.proof.clone(); - if proof_vec.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { + if x.proof.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { return Err(IndexerError::InvalidParameters(format!( "Merkle proof length ({}) is less than canopy depth ({})", - proof_vec.len(), + x.proof.len(), STATE_MERKLE_TREE_CANOPY_DEPTH, ))); } - proof_vec.truncate(proof_vec.len() - STATE_MERKLE_TREE_CANOPY_DEPTH); + let proof_len = x.proof.len() - STATE_MERKLE_TREE_CANOPY_DEPTH; - let proof = proof_vec + let proof = x.proof[..proof_len] .iter() .map(|s| Hash::from_base58(s)) .collect::, IndexerError>>() @@ -1703,15 +1702,13 @@ impl Indexer for PhotonIndexer { async fn get_subtrees( &self, - _merkle_tree_pubkey: [u8; 32], + merkle_tree_pubkey: [u8; 32], _config: Option, ) -> Result>, IndexerError> { - #[cfg(not(feature = "v2"))] - unimplemented!(); - #[cfg(feature = "v2")] - { - todo!(); - } + Err(IndexerError::NotImplemented(format!( + "PhotonIndexer::get_subtrees is not implemented for merkle tree {}", + solana_pubkey::Pubkey::new_from_array(merkle_tree_pubkey) + ))) } } diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index de97ca7739..3f52d72798 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -66,11 +66,14 @@ pub struct AddressQueueData { } impl AddressQueueData { + const ADDRESS_TREE_HEIGHT: usize = 40; + /// Reconstruct a merkle proof for a given low_element_index from the deduplicated nodes. pub fn reconstruct_proof( &self, address_idx: usize, ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -126,6 +129,7 @@ impl AddressQueueData { &self, address_range: std::ops::Range, ) -> Result, IndexerError> { + self.validate_proof_height::()?; let node_lookup = self.build_node_lookup(); let mut proofs = Vec::with_capacity(address_range.len()); @@ -140,16 +144,16 @@ impl AddressQueueData { pub fn reconstruct_all_proofs( &self, ) -> Result, IndexerError> { + self.validate_proof_height::()?; self.reconstruct_proofs::(0..self.addresses.len()) } fn build_node_lookup(&self) -> HashMap { - self.nodes - .iter() - .copied() - .enumerate() - .map(|(idx, node)| (node, idx)) - .collect() + let mut lookup = HashMap::with_capacity(self.nodes.len()); + for (idx, node) in self.nodes.iter().copied().enumerate() { + lookup.entry(node).or_insert(idx); + } + lookup } fn reconstruct_proof_with_lookup( @@ -157,6 +161,7 @@ impl AddressQueueData { address_idx: usize, node_lookup: &HashMap, ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -209,6 +214,18 @@ impl AddressQueueData { fn encode_node_index(level: usize, position: u64) -> u64 { ((level as u64) << 56) | position } + + fn validate_proof_height(&self) -> Result<(), IndexerError> { + if HEIGHT == Self::ADDRESS_TREE_HEIGHT { + return Ok(()); + } + + Err(IndexerError::InvalidParameters(format!( + "address queue proofs require HEIGHT={} but got HEIGHT={}", + Self::ADDRESS_TREE_HEIGHT, + HEIGHT + ))) + } } #[cfg(test)] diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index 061ad5074b..c70088dd40 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -53,6 +53,9 @@ pub enum LoadAccountsError { #[error("Cold PDA at index {index} (pubkey {pubkey}) missing data")] MissingPdaCompressed { index: usize, pubkey: Pubkey }, + #[error("Cold PDA (pubkey {pubkey}) missing data")] + MissingPdaCompressedData { pubkey: Pubkey }, + #[error("Cold ATA at index {index} (pubkey {pubkey}) missing data")] MissingAtaCompressed { index: usize, pubkey: Pubkey }, @@ -67,6 +70,7 @@ pub enum LoadAccountsError { } const MAX_ATAS_PER_IX: usize = 8; +const MAX_PDAS_PER_IX: usize = 8; /// Build load instructions for cold accounts. Returns empty vec if all hot. /// @@ -113,14 +117,18 @@ where }) .collect(); - let pda_hashes = collect_pda_hashes(&cold_pdas)?; + let pda_groups = group_pda_specs(&cold_pdas, MAX_PDAS_PER_IX); + let pda_hashes = pda_groups + .iter() + .map(|group| collect_pda_hashes(group)) + .collect::, _>>()?; let ata_hashes = collect_ata_hashes(&cold_atas)?; let mint_hashes = collect_mint_hashes(&cold_mints)?; let (pda_proofs, ata_proofs, mint_proofs) = futures::join!( - fetch_proofs(&pda_hashes, indexer), + fetch_proof_batches(&pda_hashes, indexer), fetch_proofs_batched(&ata_hashes, MAX_ATAS_PER_IX, indexer), - fetch_proofs(&mint_hashes, indexer), + fetch_individual_proofs(&mint_hashes, indexer), ); let pda_proofs = pda_proofs?; @@ -136,9 +144,9 @@ where // 2. DecompressAccountsIdempotent for all cold PDAs (including token PDAs). // Token PDAs are created on-chain via CPI inside DecompressVariant. - for (spec, proof) in cold_pdas.iter().zip(pda_proofs) { + for (group, proof) in pda_groups.into_iter().zip(pda_proofs) { out.push(build_pda_load( - &[spec], + &group, proof, fee_payer, compression_config, @@ -146,8 +154,7 @@ where } // 3. ATA loads (CreateAssociatedTokenAccount + Transfer2) - requires mint to exist - let ata_chunks: Vec<_> = cold_atas.chunks(MAX_ATAS_PER_IX).collect(); - for (chunk, proof) in ata_chunks.into_iter().zip(ata_proofs) { + for (chunk, proof) in cold_atas.chunks(MAX_ATAS_PER_IX).zip(ata_proofs) { out.extend(build_ata_load(chunk, proof, fee_payer)?); } @@ -195,23 +202,77 @@ fn collect_mint_hashes(ifaces: &[&AccountInterface]) -> Result, Lo .collect() } -async fn fetch_proofs( +/// Groups already-ordered PDA specs into contiguous runs of the same program id. +/// +/// This preserves input order rather than globally regrouping by program. Callers that +/// want maximal batching across interleaved program ids should sort before calling. +fn group_pda_specs<'a, V>( + specs: &[&'a PdaSpec], + max_per_group: usize, +) -> Vec>> { + assert!(max_per_group > 0, "max_per_group must be non-zero"); + if specs.is_empty() { + return Vec::new(); + } + + let mut groups = Vec::new(); + let mut current = Vec::with_capacity(max_per_group); + let mut current_program: Option = None; + + for spec in specs { + let program_id = spec.program_id(); + let should_split = current_program + .map(|existing| existing != program_id || current.len() >= max_per_group) + .unwrap_or(false); + + if should_split { + groups.push(current); + current = Vec::with_capacity(max_per_group); + } + + current_program = Some(program_id); + current.push(*spec); + } + + if !current.is_empty() { + groups.push(current); + } + + groups +} + +async fn fetch_individual_proofs( hashes: &[[u8; 32]], indexer: &I, ) -> Result, IndexerError> { if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len()); - for hash in hashes { - proofs.push( - indexer - .get_validity_proof(vec![*hash], vec![], None) - .await? - .value, - ); + + futures::future::try_join_all(hashes.iter().map(|hash| async move { + indexer + .get_validity_proof(vec![*hash], vec![], None) + .await + .map(|response| response.value) + })) + .await +} + +async fn fetch_proof_batches( + hash_batches: &[Vec<[u8; 32]>], + indexer: &I, +) -> Result, IndexerError> { + if hash_batches.is_empty() { + return Ok(vec![]); } - Ok(proofs) + + futures::future::try_join_all(hash_batches.iter().map(|hashes| async move { + indexer + .get_validity_proof(hashes.clone(), vec![], None) + .await + .map(|response| response.value) + })) + .await } async fn fetch_proofs_batched( @@ -222,16 +283,13 @@ async fn fetch_proofs_batched( if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len().div_ceil(batch_size)); - for chunk in hashes.chunks(batch_size) { - proofs.push( - indexer - .get_validity_proof(chunk.to_vec(), vec![], None) - .await? - .value, - ); - } - Ok(proofs) + + let hash_batches = hashes + .chunks(batch_size) + .map(|chunk| chunk.to_vec()) + .collect::>(); + + fetch_proof_batches(&hash_batches, indexer).await } fn build_pda_load( @@ -262,11 +320,16 @@ where let hot_addresses: Vec = specs.iter().map(|s| s.address()).collect(); let cold_accounts: Vec<(CompressedAccount, V)> = specs .iter() - .map(|s| { - let compressed = s.compressed().expect("cold spec must have data").clone(); - (compressed, s.variant.clone()) + .map(|s| -> Result<_, LoadAccountsError> { + let compressed = + s.compressed() + .cloned() + .ok_or(LoadAccountsError::MissingPdaCompressedData { + pubkey: s.address(), + })?; + Ok((compressed, s.variant.clone())) }) - .collect(); + .collect::, _>>()?; let program_id = specs.first().map(|s| s.program_id()).unwrap_or_default(); From f9a788c5ca751b1cf7679704b7e2ba6460b180e5 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sun, 15 Mar 2026 14:42:54 +0000 Subject: [PATCH 05/14] fix: harden load batching and mixed decompression --- forester/src/processor/v2/helpers.rs | 29 ++- forester/src/processor/v2/processor.rs | 6 +- forester/src/processor/v2/proof_worker.rs | 18 +- forester/src/processor/v2/strategy/address.rs | 2 +- program-tests/utils/src/e2e_test_env.rs | 3 +- .../utils/src/mock_batched_forester.rs | 9 +- .../utils/src/test_batch_forester.rs | 8 +- prover/client/src/helpers.rs | 4 +- prover/client/src/proof_client.rs | 31 ++- .../batch_address_append/proof_inputs.rs | 41 ++-- .../proof_types/batch_append/proof_inputs.rs | 2 +- .../proof_types/batch_update/proof_inputs.rs | 2 +- prover/client/src/prover.rs | 200 +++++++++++++++--- prover/client/tests/batch_address_append.rs | 131 ++++++++---- sdk-libs/client/src/indexer/photon_indexer.rs | 19 +- sdk-libs/client/src/indexer/types/queue.rs | 21 ++ .../client/src/interface/load_accounts.rs | 50 +++-- .../program-test/src/indexer/test_indexer.rs | 5 +- .../interface/program/decompression/pda.rs | 11 +- .../program/decompression/processor.rs | 15 +- .../interface/program/decompression/token.rs | 5 +- 21 files changed, 452 insertions(+), 160 deletions(-) diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index a0f3e3bb5b..a9fa2e290d 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -493,12 +493,29 @@ impl StreamingAddressQueue { hashchain_idx: usize, ) -> crate::Result>> { let available = self.wait_for_batch(end); - if start >= available { + if available < end || start >= end { return Ok(None); } - let actual_end = end.min(available); + let actual_end = end; let data = lock_recover(&self.data, "streaming_address_queue.data"); + for (name, len) in [ + ("addresses", data.addresses.len()), + ("low_element_values", data.low_element_values.len()), + ("low_element_next_values", data.low_element_next_values.len()), + ("low_element_indices", data.low_element_indices.len()), + ("low_element_next_indices", data.low_element_next_indices.len()), + ] { + if len < actual_end { + return Err(anyhow!( + "incomplete batch data: {} len {} < required end {}", + name, + len, + actual_end + )); + } + } + let addresses = data.addresses[start..actual_end].to_vec(); if addresses.is_empty() { return Err(anyhow!("Empty batch at start={}", start)); @@ -528,7 +545,9 @@ impl StreamingAddressQueue { low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), low_element_indices: data.low_element_indices[start..actual_end].to_vec(), low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - low_element_proofs: data.reconstruct_proofs::(start..actual_end)?, + low_element_proofs: data.reconstruct_proofs::(start..actual_end).map_err( + |error| anyhow!("incomplete batch data: failed to reconstruct proofs: {error}"), + )?, addresses, leaves_hashchain, })) @@ -566,6 +585,10 @@ impl StreamingAddressQueue { lock_recover(&self.data, "streaming_address_queue.data").start_index } + pub fn tree_next_insertion_index(&self) -> u64 { + lock_recover(&self.data, "streaming_address_queue.data").tree_next_insertion_index + } + pub fn subtrees(&self) -> Vec<[u8; 32]> { lock_recover(&self.data, "streaming_address_queue.data") .subtrees diff --git a/forester/src/processor/v2/processor.rs b/forester/src/processor/v2/processor.rs index 3de6dea860..372a800e0e 100644 --- a/forester/src/processor/v2/processor.rs +++ b/forester/src/processor/v2/processor.rs @@ -132,7 +132,7 @@ where } if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } @@ -532,7 +532,7 @@ where ((queue_size / self.zkp_batch_size) as usize).min(self.context.max_batches_per_tree); if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } @@ -561,7 +561,7 @@ where let max_batches = max_batches.min(self.context.max_batches_per_tree); if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } diff --git a/forester/src/processor/v2/proof_worker.rs b/forester/src/processor/v2/proof_worker.rs index b7afeacf0b..ded9fcedc8 100644 --- a/forester/src/processor/v2/proof_worker.rs +++ b/forester/src/processor/v2/proof_worker.rs @@ -132,27 +132,27 @@ struct ProofClients { } impl ProofClients { - fn new(config: &ProverConfig) -> Self { - Self { + fn new(config: &ProverConfig) -> crate::Result { + Ok(Self { append_client: ProofClient::with_config( config.append_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), + )?, nullify_client: ProofClient::with_config( config.update_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), + )?, address_append_client: ProofClient::with_config( config.address_append_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), - } + )?, + }) } fn get_client(&self, input: &ProofInput) -> &ProofClient { @@ -164,11 +164,11 @@ impl ProofClients { } } -pub fn spawn_proof_workers(config: &ProverConfig) -> async_channel::Sender { +pub fn spawn_proof_workers(config: &ProverConfig) -> crate::Result> { let (job_tx, job_rx) = async_channel::bounded::(256); - let clients = Arc::new(ProofClients::new(config)); + let clients = Arc::new(ProofClients::new(config)?); tokio::spawn(async move { run_proof_pipeline(job_rx, clients).await }); - job_tx + Ok(job_tx) } async fn run_proof_pipeline( diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs index 51ab05143a..51236c389b 100644 --- a/forester/src/processor/v2/strategy/address.rs +++ b/forester/src/processor/v2/strategy/address.rs @@ -167,7 +167,7 @@ impl TreeStrategy for AddressTreeStrategy { } let initial_root = streaming_queue.initial_root(); - let start_index = streaming_queue.start_index(); + let start_index = streaming_queue.tree_next_insertion_index(); let subtrees_arr: [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize] = subtrees.try_into().map_err(|v: Vec<[u8; 32]>| { diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index 097ae1f9a9..6c9fdb5d5e 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -764,7 +764,8 @@ where // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = address_queue.start_index as usize; + let start_index = + address_queue.tree_next_insertion_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 0101b235aa..8ea51b7169 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -132,7 +132,8 @@ impl MockBatchedForester { assert_eq!(computed_new_root, self.merkle_tree.root()); - let proof_result = match ProofClient::local() + let proof_client = ProofClient::local()?; + let proof_result = match proof_client .generate_batch_append_proof(circuit_inputs) .await { @@ -207,7 +208,8 @@ impl MockBatchedForester { batch_size, &[], )?; - let proof_result = ProofClient::local() + let proof_client = ProofClient::local()?; + let proof_result = proof_client .generate_batch_update_proof(inputs) .await?; let new_root = self.merkle_tree.root(); @@ -318,7 +320,8 @@ impl MockBatchedAddressForester { ))); } }; - let proof_result = match ProofClient::local() + let proof_client = ProofClient::local()?; + let proof_result = match proof_client .generate_batch_address_append_proof(inputs) .await { diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index a28efa9ca3..8cec32757f 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -164,7 +164,7 @@ pub async fn create_append_batch_ix_data( bigint_to_be_bytes_array::<32>(&circuit_inputs.new_root.to_biguint().unwrap()).unwrap(), bundle.merkle_tree.root() ); - let proof_client = ProofClient::local(); + let proof_client = ProofClient::local().unwrap(); let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); match proof_client.generate_proof(inputs_json).await { @@ -293,7 +293,7 @@ pub async fn get_batched_nullify_ix_data( &[], ) .unwrap(); - let proof_client = ProofClient::local(); + let proof_client = ProofClient::local().unwrap(); let circuit_inputs_new_root = bigint_to_be_bytes_array::<32>(&inputs.new_root.to_biguint().unwrap()).unwrap(); let inputs_json = update_inputs_string(&inputs); @@ -670,7 +670,7 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof= 1, "start index should be greater than 2 else tree is not inited" @@ -715,7 +715,7 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof(&inputs.new_root).unwrap(); let inputs_json = to_json(&inputs); diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 98457479e2..9a20b8958e 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -49,9 +49,9 @@ pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], Box( leaf: [u8; 32], path_elements: &[[u8; 32]; HEIGHT], - path_index: u32, + path_index: usize, ) -> Result<([u8; 32], ChangelogEntry), ProverClientError> { - let mut changelog_entry = ChangelogEntry::default_with_index(path_index as usize); + let mut changelog_entry = ChangelogEntry::default_with_index(path_index); let mut current_hash = leaf; let mut current_index = path_index; diff --git a/prover/client/src/proof_client.rs b/prover/client/src/proof_client.rs index 1d557407bd..820c9fe07d 100644 --- a/prover/client/src/proof_client.rs +++ b/prover/client/src/proof_client.rs @@ -6,7 +6,7 @@ use tokio::time::sleep; use tracing::{debug, error, info, trace, warn}; use crate::{ - constants::PROVE_PATH, + constants::{PROVE_PATH, SERVER_ADDRESS}, errors::ProverClientError, proof::{ compress_proof, deserialize_gnark_proof_json, proof_from_json_struct, ProofCompressed, @@ -17,14 +17,13 @@ use crate::{ batch_append::{BatchAppendInputsJson, BatchAppendsCircuitInputs}, batch_update::{update_inputs_string, BatchUpdateCircuitInputs}, }, + prover::build_http_client, }; const MAX_RETRIES: u32 = 10; const BASE_RETRY_DELAY_SECS: u64 = 1; const DEFAULT_POLLING_INTERVAL_MS: u64 = 100; const DEFAULT_MAX_WAIT_TIME_SECS: u64 = 600; -const DEFAULT_LOCAL_SERVER: &str = "http://localhost:3001"; - const INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS: u64 = 200; const INITIAL_POLL_DELAY_LARGE_CIRCUIT_MS: u64 = 200; @@ -68,15 +67,15 @@ pub struct ProofClient { } impl ProofClient { - pub fn local() -> Self { - Self { - client: Client::new(), - server_address: DEFAULT_LOCAL_SERVER.to_string(), + pub fn local() -> Result { + Ok(Self { + client: build_http_client()?, + server_address: SERVER_ADDRESS.to_string(), polling_interval: Duration::from_millis(DEFAULT_POLLING_INTERVAL_MS), max_wait_time: Duration::from_secs(DEFAULT_MAX_WAIT_TIME_SECS), api_key: None, initial_poll_delay: Duration::from_millis(INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS), - } + }) } #[allow(unused)] @@ -85,21 +84,21 @@ impl ProofClient { polling_interval: Duration, max_wait_time: Duration, api_key: Option, - ) -> Self { + ) -> Result { let initial_poll_delay = if api_key.is_some() { Duration::from_millis(INITIAL_POLL_DELAY_LARGE_CIRCUIT_MS) } else { Duration::from_millis(INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS) }; - Self { - client: Client::new(), + Ok(Self { + client: build_http_client()?, server_address, polling_interval, max_wait_time, api_key, initial_poll_delay, - } + }) } #[allow(unused)] @@ -109,15 +108,15 @@ impl ProofClient { max_wait_time: Duration, api_key: Option, initial_poll_delay: Duration, - ) -> Self { - Self { - client: Client::new(), + ) -> Result { + Ok(Self { + client: build_http_client()?, server_address, polling_interval, max_wait_time, api_key, initial_poll_delay, - } + }) } pub async fn submit_proof_async( diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index 32408fdc02..231fa17ac5 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -199,14 +199,31 @@ pub fn get_batch_address_append_circuit_inputs( changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - let new_element_values = &new_element_values[..zkp_batch_size]; + let batch_len = zkp_batch_size; + for (name, len) in [ + ("new_element_values", new_element_values.len()), + ("low_element_values", low_element_values.len()), + ("low_element_next_values", low_element_next_values.len()), + ("low_element_indices", low_element_indices.len()), + ("low_element_next_indices", low_element_next_indices.len()), + ("low_element_proofs", low_element_proofs.len()), + ] { + if len < batch_len { + return Err(ProverClientError::GenericError(format!( + "truncated batch from indexer: {} len {} < required batch size {}", + name, len, batch_len + ))); + } + } + + let new_element_values = &new_element_values[..batch_len]; let mut new_root = [0u8; 32]; - let mut low_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); - let mut new_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); - let mut patched_low_element_next_values = Vec::with_capacity(new_element_values.len()); - let mut patched_low_element_next_indices = Vec::with_capacity(new_element_values.len()); - let mut patched_low_element_values = Vec::with_capacity(new_element_values.len()); - let mut patched_low_element_indices = Vec::with_capacity(new_element_values.len()); + let mut low_element_circuit_merkle_proofs = Vec::with_capacity(batch_len); + let mut new_element_circuit_merkle_proofs = Vec::with_capacity(batch_len); + let mut patched_low_element_next_values = Vec::with_capacity(batch_len); + let mut patched_low_element_next_indices = Vec::with_capacity(batch_len); + let mut patched_low_element_values = Vec::with_capacity(batch_len); + let mut patched_low_element_indices = Vec::with_capacity(batch_len); let computed_hashchain = create_hash_chain_from_slice(new_element_values).map_err(|e| { ProverClientError::GenericError(format!("Failed to compute hashchain: {}", e)) @@ -241,7 +258,7 @@ pub fn get_batch_address_append_circuit_inputs( let is_first_batch = indexed_changelog.is_empty(); let mut expected_root_for_low = current_root; - for i in 0..new_element_values.len() { + for i in 0..batch_len { let mut changelog_index = 0; let low_element_index = low_element_indices[i].try_into().map_err(|_| { ProverClientError::IntegerConversion(format!( @@ -322,7 +339,7 @@ pub fn get_batch_address_append_circuit_inputs( let (computed_root, _) = compute_root_from_merkle_proof::( old_low_leaf_hash, &merkle_proof, - low_element.index as u32, + low_element.index, )?; if computed_root != expected_root_for_low { let low_value_bytes = bigint_to_be_bytes_array::<32>(&low_element.value) @@ -363,7 +380,7 @@ pub fn get_batch_address_append_circuit_inputs( compute_root_from_merkle_proof::( new_low_leaf_hash, &merkle_proof, - new_low_element.index as u32, + new_low_element.index, )?; patcher.push_changelog_entry::(changelog, changelog_entry); @@ -404,7 +421,7 @@ pub fn get_batch_address_append_circuit_inputs( let (updated_root, changelog_entry) = compute_root_from_merkle_proof( new_element_leaf_hash, &merkle_proof_array, - current_index as u32, + current_index, )?; if i == 0 && changelog.len() == 1 { @@ -431,7 +448,7 @@ pub fn get_batch_address_append_circuit_inputs( let (root_with_zero, _) = compute_root_from_merkle_proof::( zero_hash, &merkle_proof_array, - current_index as u32, + current_index, )?; if root_with_zero != intermediate_root { tracing::error!( diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index 7dd578e599..41a6dcfcd6 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -190,7 +190,7 @@ pub fn get_batch_append_inputs( let (updated_root, changelog_entry) = compute_root_from_merkle_proof( final_leaf, &merkle_proof_array, - start_index + i as u32, + start_index as usize + i, )?; new_root = updated_root; changelog.push(changelog_entry); diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 7f8c08e0d1..2ada02b92b 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -175,7 +175,7 @@ pub fn get_batch_update_inputs( index_bytes[28..].copy_from_slice(&(*index).to_be_bytes()); let nullifier = Poseidon::hashv(&[leaf, &index_bytes, &tx_hashes[i]]).unwrap(); let (root, changelog_entry) = - compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index)?; + compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index as usize)?; new_root = root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 56ae20d98a..5b5cfacc77 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -1,19 +1,120 @@ use std::{ - process::Command, + io::{Read, Write}, + net::{TcpStream, ToSocketAddrs}, + process::{Command, Stdio}, sync::atomic::{AtomicBool, Ordering}, - thread::sleep, time::Duration, }; use tracing::info; +use tokio::time::sleep; use crate::{ constants::{HEALTH_CHECK, SERVER_ADDRESS}, + errors::ProverClientError, helpers::get_project_root, }; static IS_LOADING: AtomicBool = AtomicBool::new(false); +pub(crate) fn build_http_client() -> Result { + reqwest::Client::builder() + .no_proxy() + .build() + .map_err(|error| { + ProverClientError::GenericError(format!("failed to build HTTP client: {error}")) + }) +} + +fn health_check_once(timeout: Duration) -> bool { + if prover_listener_present() { + return true; + } + + let endpoint = SERVER_ADDRESS + .strip_prefix("http://") + .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) + .unwrap_or(SERVER_ADDRESS); + let addr = match endpoint.to_socket_addrs().ok().and_then(|mut addrs| addrs.next()) { + Some(addr) => addr, + None => return false, + }; + + let mut stream = match TcpStream::connect_timeout(&addr, timeout) { + Ok(stream) => stream, + Err(error) => { + tracing::debug!(?error, endpoint, "prover health TCP connect failed"); + return health_check_once_with_curl(timeout); + } + }; + + let _ = stream.set_read_timeout(Some(timeout)); + let _ = stream.set_write_timeout(Some(timeout)); + + let host = endpoint.split(':').next().unwrap_or("127.0.0.1"); + let request = + format!("GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", HEALTH_CHECK, host); + if let Err(error) = stream.write_all(request.as_bytes()) { + tracing::debug!(?error, "failed to write prover health request"); + return health_check_once_with_curl(timeout); + } + + let mut response = [0u8; 512]; + let bytes_read = match stream.read(&mut response) { + Ok(bytes_read) => bytes_read, + Err(error) => { + tracing::debug!(?error, "failed to read prover health response"); + return health_check_once_with_curl(timeout); + } + }; + + if bytes_read == 0 { + return false; + } + + let response = std::str::from_utf8(&response[..bytes_read]).unwrap_or_default(); + response.contains("200 OK") + || response.contains("{\"status\":\"ok\"}") + || health_check_once_with_curl(timeout) +} + +fn prover_listener_present() -> bool { + let endpoint = SERVER_ADDRESS + .strip_prefix("http://") + .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) + .unwrap_or(SERVER_ADDRESS); + let port = endpoint.rsplit(':').next().unwrap_or("3001"); + + match Command::new("lsof") + .args(["-nP", &format!("-iTCP:{port}"), "-sTCP:LISTEN"]) + .output() + { + Ok(output) => output.status.success() && !output.stdout.is_empty(), + Err(error) => { + tracing::debug!(?error, "failed to execute lsof prover listener check"); + false + } + } +} + +fn health_check_once_with_curl(timeout: Duration) -> bool { + let timeout_secs = timeout.as_secs().max(1).to_string(); + let url = format!("{}{}", SERVER_ADDRESS, HEALTH_CHECK); + match Command::new("curl") + .args(["-sS", "-m", timeout_secs.as_str(), url.as_str()]) + .output() + { + Ok(output) => { + output.status.success() + && String::from_utf8_lossy(&output.stdout).contains("{\"status\":\"ok\"}") + } + Err(error) => { + tracing::debug!(?error, "failed to execute curl prover health check"); + false + } + } +} + pub async fn spawn_prover() { if let Some(_project_root) = get_project_root() { let prover_path: &str = { @@ -28,48 +129,81 @@ pub async fn spawn_prover() { } }; - if !health_check(10, 1).await && !IS_LOADING.load(Ordering::Relaxed) { - IS_LOADING.store(true, Ordering::Relaxed); + if health_check(10, 1).await { + return; + } + + if IS_LOADING + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + if health_check(120, 1).await { + return; + } + panic!("Failed to start prover, health check failed."); + } + + let spawn_result = async { + let mut command = Command::new(prover_path); + command.arg("start-prover").stdout(Stdio::piped()).stderr(Stdio::piped()); + let mut child = command.spawn().expect("Failed to start prover process"); + let mut child_exit_status = None; - let command = Command::new(prover_path) - .arg("start-prover") - .spawn() - .expect("Failed to start prover process"); + for _ in 0..120 { + if health_check(1, 1).await { + info!("Prover started successfully"); + return; + } - let _ = command.wait_with_output(); + if child_exit_status.is_none() { + match child.try_wait() { + Ok(Some(status)) => { + tracing::warn!( + ?status, + "prover launcher exited before health check succeeded; continuing to poll for detached prover" + ); + child_exit_status = Some(status); + } + Ok(None) => {} + Err(error) => { + tracing::error!(?error, "failed to poll prover child process"); + } + } + } - let health_result = health_check(120, 1).await; - if health_result { - info!("Prover started successfully"); - } else { - panic!("Failed to start prover, health check failed."); + sleep(Duration::from_secs(1)).await; } + + if let Some(status) = child_exit_status { + panic!( + "Failed to start prover, health check failed after launcher exited with status {status}." + ); + } + + panic!("Failed to start prover, health check failed."); } + .await; + + IS_LOADING.store(false, Ordering::Release); + spawn_result } else { panic!("Failed to find project root."); }; } pub async fn health_check(retries: usize, timeout: usize) -> bool { - let client = match reqwest::Client::builder().no_proxy().build() { - Ok(client) => client, - Err(_) => return false, - }; - let mut result = false; - for _ in 0..retries { - match client - .get(format!("{}{}", SERVER_ADDRESS, HEALTH_CHECK)) - .send() - .await - { - Ok(_) => { - result = true; - break; - } - Err(_) => { - sleep(Duration::from_secs(timeout as u64)); - } + let timeout = Duration::from_secs(timeout as u64); + let retry_delay = timeout; + + for attempt in 0..retries { + if health_check_once(timeout) { + return true; + } + + if attempt + 1 < retries { + sleep(retry_delay).await; } } - result + + false } diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index ac73c3809e..8a02c363ff 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -26,53 +26,55 @@ async fn prove_batch_address_append() { spawn_prover().await; // Initialize test data - let mut new_element_values = vec![]; - let zkp_batch_size = 10; - for i in 1..zkp_batch_size + 1 { - new_element_values.push(num_bigint::ToBigUint::to_biguint(&i).unwrap()); - } + let total_batch_size = 10usize; + let warmup_batch_size = 1usize; + let prior_value = 999_u32.to_biguint().unwrap(); + let new_element_values = (1..=total_batch_size) + .map(|i| num_bigint::ToBigUint::to_biguint(&i).unwrap()) + .collect::>(); // Initialize indexing structures - let relayer_merkle_tree = + let mut relayer_merkle_tree = IndexedMerkleTree::::new(DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize, 0) .unwrap(); - let start_index = relayer_merkle_tree.merkle_tree.rightmost_index; - let current_root = relayer_merkle_tree.root(); + let collect_non_inclusion_data = + |tree: &IndexedMerkleTree, values: &[BigUint]| { + let mut low_element_values = Vec::with_capacity(values.len()); + let mut low_element_indices = Vec::with_capacity(values.len()); + let mut low_element_next_indices = Vec::with_capacity(values.len()); + let mut low_element_next_values = Vec::with_capacity(values.len()); + let mut low_element_proofs: Vec< + [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize], + > = Vec::with_capacity(values.len()); - // Prepare proof components - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = - Vec::new(); + for new_element_value in values { + let non_inclusion_proof = tree.get_non_inclusion_proof(new_element_value).unwrap(); - // Generate non-inclusion proofs for each element - for new_element_value in &new_element_values { - let non_inclusion_proof = relayer_merkle_tree - .get_non_inclusion_proof(new_element_value) - .unwrap(); + low_element_values.push(non_inclusion_proof.leaf_lower_range_value); + low_element_indices.push(non_inclusion_proof.leaf_index); + low_element_next_indices.push(non_inclusion_proof.next_index); + low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); + } - low_element_values.push(non_inclusion_proof.leaf_lower_range_value); - low_element_indices.push(non_inclusion_proof.leaf_index); - low_element_next_indices.push(non_inclusion_proof.next_index); - low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push( - non_inclusion_proof - .merkle_proof - .as_slice() - .try_into() - .unwrap(), - ); - } + ( + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + low_element_proofs, + ) + }; - // Convert big integers to byte arrays - let new_element_values = new_element_values - .iter() - .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) - .collect::>(); - let hash_chain = create_hash_chain_from_slice(&new_element_values).unwrap(); + let initial_start_index = relayer_merkle_tree.merkle_tree.rightmost_index; + let initial_root = relayer_merkle_tree.root(); let subtrees: [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize] = relayer_merkle_tree .merkle_tree @@ -82,7 +84,7 @@ async fn prove_batch_address_append() { let mut sparse_merkle_tree = SparseMerkleTree::< Poseidon, { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, - >::new(subtrees, start_index); + >::new(subtrees, initial_start_index); let mut changelog: Vec> = Vec::new(); @@ -90,6 +92,55 @@ async fn prove_batch_address_append() { IndexedChangelogEntry, > = Vec::new(); + let warmup_values = vec![prior_value.clone()]; + let ( + warmup_low_element_values, + warmup_low_element_indices, + warmup_low_element_next_indices, + warmup_low_element_next_values, + warmup_low_element_proofs, + ) = collect_non_inclusion_data(&relayer_merkle_tree, &warmup_values); + let warmup_values = warmup_values + .iter() + .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) + .collect::>(); + let warmup_hash_chain = create_hash_chain_from_slice(&warmup_values).unwrap(); + + get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( + initial_start_index, + initial_root, + &warmup_low_element_values, + &warmup_low_element_next_values, + &warmup_low_element_indices, + &warmup_low_element_next_indices, + &warmup_low_element_proofs, + &warmup_values, + &mut sparse_merkle_tree, + warmup_hash_chain, + warmup_batch_size, + &mut changelog, + &mut indexed_changelog, + ) + .unwrap(); + + relayer_merkle_tree.append(&prior_value).unwrap(); + + let remaining_values = &new_element_values[..]; + let ( + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + low_element_proofs, + ) = collect_non_inclusion_data(&relayer_merkle_tree, remaining_values); + let new_element_values = remaining_values + .iter() + .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) + .collect::>(); + let hash_chain = create_hash_chain_from_slice(&new_element_values).unwrap(); + let start_index = relayer_merkle_tree.merkle_tree.rightmost_index; + let current_root = relayer_merkle_tree.root(); + let inputs = get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( start_index, @@ -102,7 +153,7 @@ async fn prove_batch_address_append() { &new_element_values, &mut sparse_merkle_tree, hash_chain, - zkp_batch_size, + total_batch_size, &mut changelog, &mut indexed_changelog, ) diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index 5698719c8f..eb8890d6b7 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, time::Duration}; use async_trait::async_trait; use bs58; -use light_sdk_types::constants::STATE_MERKLE_TREE_CANOPY_DEPTH; +use light_sdk_types::constants::{STATE_MERKLE_TREE_CANOPY_DEPTH, STATE_MERKLE_TREE_HEIGHT}; use photon_api::apis::configuration::Configuration; use solana_pubkey::Pubkey; use tracing::{error, trace, warn}; @@ -1142,14 +1142,24 @@ impl Indexer for PhotonIndexer { .value .iter() .map(|x| { - if x.proof.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { + let expected_siblings = + STATE_MERKLE_TREE_HEIGHT - STATE_MERKLE_TREE_CANOPY_DEPTH; + let expected_total = STATE_MERKLE_TREE_CANOPY_DEPTH + expected_siblings; + if x.proof.len() != expected_total { return Err(IndexerError::InvalidParameters(format!( - "Merkle proof length ({}) is less than canopy depth ({})", + "Merkle proof length ({}) does not match expected total proof length ({})", x.proof.len(), - STATE_MERKLE_TREE_CANOPY_DEPTH, + expected_total, ))); } let proof_len = x.proof.len() - STATE_MERKLE_TREE_CANOPY_DEPTH; + if proof_len != expected_siblings { + return Err(IndexerError::InvalidParameters(format!( + "Merkle proof sibling count ({}) does not match expected sibling count ({})", + proof_len, + expected_siblings, + ))); + } let proof = x.proof[..proof_len] .iter() @@ -1681,6 +1691,7 @@ impl Indexer for PhotonIndexer { .map(|h| super::base58::decode_base58_to_fixed_array(&h.0)) .collect::, _>>()?, start_index: aq.start_index, + tree_next_insertion_index: aq.start_index, root_seq: aq.root_seq, }) } else { diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index 3f52d72798..79fcb45cf1 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -61,7 +61,10 @@ pub struct AddressQueueData { pub initial_root: [u8; 32], pub leaves_hash_chains: Vec<[u8; 32]>, pub subtrees: Vec<[u8; 32]>, + /// Pagination offset for the returned queue slice. pub start_index: u64, + /// Sparse tree insertion point / next index used to initialize staging trees. + pub tree_next_insertion_index: u64, pub root_seq: u64, } @@ -130,6 +133,19 @@ impl AddressQueueData { address_range: std::ops::Range, ) -> Result, IndexerError> { self.validate_proof_height::()?; + let available = self.proof_count(); + if address_range.start > address_range.end { + return Err(IndexerError::InvalidParameters(format!( + "invalid address proof range {}..{}", + address_range.start, address_range.end + ))); + } + if address_range.end > available { + return Err(IndexerError::InvalidParameters(format!( + "address proof range {}..{} exceeds available proofs {}", + address_range.start, address_range.end, available + ))); + } let node_lookup = self.build_node_lookup(); let mut proofs = Vec::with_capacity(address_range.len()); @@ -156,6 +172,10 @@ impl AddressQueueData { lookup } + fn proof_count(&self) -> usize { + self.addresses.len().min(self.low_element_indices.len()) + } + fn reconstruct_proof_with_lookup( &self, address_idx: usize, @@ -279,6 +299,7 @@ mod tests { leaves_hash_chains: vec![[3u8; 32]; num_addresses.max(1)], subtrees: vec![[4u8; 32]; HEIGHT], start_index: 0, + tree_next_insertion_index: 0, root_seq: 0, } } diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index c70088dd40..0d564734a2 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -1,5 +1,6 @@ //! Load cold accounts API. +use futures::{stream, StreamExt, TryStreamExt}; use light_account::{derive_rent_sponsor_pda, Pack}; use light_compressed_account::{ compressed_account::PackedMerkleContext, instruction_data::compressed_proof::ValidityProof, @@ -71,6 +72,7 @@ pub enum LoadAccountsError { const MAX_ATAS_PER_IX: usize = 8; const MAX_PDAS_PER_IX: usize = 8; +const PROOF_FETCH_CONCURRENCY: usize = 8; /// Build load instructions for cold accounts. Returns empty vec if all hot. /// @@ -118,9 +120,14 @@ where .collect(); let pda_groups = group_pda_specs(&cold_pdas, MAX_PDAS_PER_IX); + let mut pda_offset = 0usize; let pda_hashes = pda_groups .iter() - .map(|group| collect_pda_hashes(group)) + .map(|group| { + let hashes = collect_pda_hashes(group, pda_offset)?; + pda_offset += group.len(); + Ok::<_, LoadAccountsError>(hashes) + }) .collect::, _>>()?; let ata_hashes = collect_ata_hashes(&cold_atas)?; let mint_hashes = collect_mint_hashes(&cold_mints)?; @@ -161,13 +168,16 @@ where Ok(out) } -fn collect_pda_hashes(specs: &[&PdaSpec]) -> Result, LoadAccountsError> { +fn collect_pda_hashes( + specs: &[&PdaSpec], + start_index: usize, +) -> Result, LoadAccountsError> { specs .iter() .enumerate() .map(|(i, s)| { s.hash().ok_or(LoadAccountsError::MissingPdaCompressed { - index: i, + index: start_index + i, pubkey: s.address(), }) }) @@ -249,13 +259,16 @@ async fn fetch_individual_proofs( return Ok(vec![]); } - futures::future::try_join_all(hashes.iter().map(|hash| async move { - indexer - .get_validity_proof(vec![*hash], vec![], None) - .await - .map(|response| response.value) - })) - .await + stream::iter(hashes.iter().copied()) + .map(|hash| async move { + indexer + .get_validity_proof(vec![hash], vec![], None) + .await + .map(|response| response.value) + }) + .buffered(PROOF_FETCH_CONCURRENCY) + .try_collect() + .await } async fn fetch_proof_batches( @@ -266,13 +279,16 @@ async fn fetch_proof_batches( return Ok(vec![]); } - futures::future::try_join_all(hash_batches.iter().map(|hashes| async move { - indexer - .get_validity_proof(hashes.clone(), vec![], None) - .await - .map(|response| response.value) - })) - .await + stream::iter(hash_batches.iter().cloned()) + .map(|hashes| async move { + indexer + .get_validity_proof(hashes, vec![], None) + .await + .map(|response| response.value) + }) + .buffered(PROOF_FETCH_CONCURRENCY) + .try_collect() + .await } async fn fetch_proofs_batched( diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index c51298b9cd..e799d3f29e 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -726,9 +726,8 @@ impl Indexer for TestIndexer { initial_root: address_tree_bundle.root(), leaves_hash_chains: Vec::new(), subtrees: address_tree_bundle.get_subtrees(), - // Consumers use start_index as the sparse tree's next insertion index, - // not the pagination offset used for queue slicing. - start_index: address_tree_bundle.right_most_index() as u64, + start_index: start as u64, + tree_next_insertion_index: address_tree_bundle.right_most_index() as u64, root_seq: address_tree_bundle.sequence_number(), }) } else { diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs index 3e32ec6ef3..8819724f59 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs @@ -143,12 +143,21 @@ where let address = derive_address(&pda_key, &ctx.light_config.address_space[0], ctx.program_id); // 10. Build CompressedAccountInfo for CPI + // When PDA decompression is only the first phase of a later token Transfer2 execution, + // the stored input queue index must match that later execution basis, not the original + // packed proof basis. The mixed PDA+token flow uses `output_queue_index` for that. + let input_queue_index = if ctx.cpi_accounts.config().cpi_context { + output_queue_index + } else { + tree_info.queue_pubkey_index + }; + let input = InAccountInfo { data_hash: input_data_hash, lamports: 0, merkle_context: PackedMerkleContext { merkle_tree_pubkey_index: tree_info.merkle_tree_pubkey_index, - queue_pubkey_index: tree_info.queue_pubkey_index, + queue_pubkey_index: input_queue_index, leaf_index: tree_info.leaf_index, prove_by_index: tree_info.prove_by_index, }, diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs b/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs index 2fc7cfc811..2fc1f037b8 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs @@ -126,7 +126,7 @@ pub struct DecompressCtx<'a, AI: AccountInfoTrait + Clone> { #[cfg(feature = "token")] pub in_tlv: Option>>, #[cfg(feature = "token")] - pub token_seeds: Vec>, + pub token_seeds: Vec>>, } // ============================================================================ @@ -296,7 +296,7 @@ pub struct DecompressAccountsBuilt<'a, AI: AccountInfoTrait + Clone> { pub cpi_context: bool, pub in_token_data: Vec, pub in_tlv: Option>>, - pub token_seeds: Vec>, + pub token_seeds: Vec>>, } /// Validates accounts, dispatches all variants, and collects CPI inputs for @@ -649,13 +649,20 @@ where .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; } else { // At least one regular token account - use invoke_signed with PDA seeds - let signer_seed_refs: Vec<&[u8]> = token_seeds.iter().map(|s| s.as_slice()).collect(); + let signer_seed_storage: Vec> = token_seeds + .iter() + .map(|seed_group| seed_group.iter().map(|seed| seed.as_slice()).collect()) + .collect(); + let signer_seed_refs: Vec<&[&[u8]]> = signer_seed_storage + .iter() + .map(|seed_group| seed_group.as_slice()) + .collect(); AI::invoke_cpi( &LIGHT_TOKEN_PROGRAM_ID, &transfer2_data, &account_metas, remaining_accounts, - &[signer_seed_refs.as_slice()], + signer_seed_refs.as_slice(), ) .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; } diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/token.rs b/sdk-libs/sdk-types/src/interface/program/decompression/token.rs index 153943e275..2b6fa37a7a 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/token.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/token.rs @@ -148,8 +148,9 @@ where ) .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; - // Push seeds for the Transfer2 CPI (needed for invoke_signed) - ctx.token_seeds.extend(seeds.iter().map(|s| s.to_vec())); + // Push one signer seed group per vault PDA for the later Transfer2 CPI. + ctx.token_seeds + .push(seeds.iter().map(|seed| seed.to_vec()).collect()); } // Push token data for the Transfer2 CPI (common for both ATA and regular paths) From 7ab1c909486b75162eca7d29c7445d611e50a9a5 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Mon, 16 Mar 2026 14:50:22 +0000 Subject: [PATCH 06/14] Fix prover startup and decompression load flow --- prover/client/src/errors.rs | 1 - prover/client/src/prover.rs | 81 +++----- sdk-libs/client/src/interface/instructions.rs | 10 +- .../program-test/src/indexer/test_indexer.rs | 191 ++++++++++++++---- .../src/interface/account/token_seeds.rs | 4 +- .../interface/program/decompression/pda.rs | 16 +- .../tests/basic_test.rs | 9 +- 7 files changed, 200 insertions(+), 112 deletions(-) diff --git a/prover/client/src/errors.rs b/prover/client/src/errors.rs index e095bf3579..859cae32b8 100644 --- a/prover/client/src/errors.rs +++ b/prover/client/src/errors.rs @@ -39,7 +39,6 @@ pub enum ProverClientError { #[error("Integer conversion failed: {0}")] IntegerConversion(String), - #[error("Hashchain mismatch: computed {computed:?} != expected {expected:?} (batch_size={batch_size}, next_index={next_index})")] HashchainMismatch { computed: [u8; 32], diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 5b5cfacc77..e0f9d4060d 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -1,13 +1,13 @@ use std::{ io::{Read, Write}, net::{TcpStream, ToSocketAddrs}, - process::{Command, Stdio}, + process::Command, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; -use tracing::info; use tokio::time::sleep; +use tracing::info; use crate::{ constants::{HEALTH_CHECK, SERVER_ADDRESS}, @@ -16,6 +16,18 @@ use crate::{ }; static IS_LOADING: AtomicBool = AtomicBool::new(false); +const STARTUP_HEALTH_CHECK_RETRIES: usize = 300; + +fn has_http_ok_status(response: &[u8]) -> bool { + response + .split(|&byte| byte == b'\n') + .next() + .map(|status_line| { + status_line.starts_with(b"HTTP/") + && status_line.windows(5).any(|window| window == b" 200 ") + }) + .unwrap_or(false) +} pub(crate) fn build_http_client() -> Result { reqwest::Client::builder() @@ -59,7 +71,7 @@ fn health_check_once(timeout: Duration) -> bool { return health_check_once_with_curl(timeout); } - let mut response = [0u8; 512]; + let mut response = [0_u8; 512]; let bytes_read = match stream.read(&mut response) { Ok(bytes_read) => bytes_read, Err(error) => { @@ -68,14 +80,8 @@ fn health_check_once(timeout: Duration) -> bool { } }; - if bytes_read == 0 { - return false; - } - - let response = std::str::from_utf8(&response[..bytes_read]).unwrap_or_default(); - response.contains("200 OK") - || response.contains("{\"status\":\"ok\"}") - || health_check_once_with_curl(timeout) + bytes_read > 0 + && (has_http_ok_status(&response[..bytes_read]) || health_check_once_with_curl(timeout)) } fn prover_listener_present() -> bool { @@ -117,15 +123,15 @@ fn health_check_once_with_curl(timeout: Duration) -> bool { pub async fn spawn_prover() { if let Some(_project_root) = get_project_root() { - let prover_path: &str = { + let prover_path = { #[cfg(feature = "devenv")] { - &format!("{}/{}", _project_root.trim(), "cli/test_bin/run") + format!("{}/{}", _project_root.trim(), "cli/test_bin/run") } #[cfg(not(feature = "devenv"))] { println!("Running in production mode, using prover binary"); - "light" + "light".to_string() } }; @@ -137,50 +143,23 @@ pub async fn spawn_prover() { .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) .is_err() { - if health_check(120, 1).await { + if health_check(STARTUP_HEALTH_CHECK_RETRIES, 1).await { return; } panic!("Failed to start prover, health check failed."); } let spawn_result = async { - let mut command = Command::new(prover_path); - command.arg("start-prover").stdout(Stdio::piped()).stderr(Stdio::piped()); - let mut child = command.spawn().expect("Failed to start prover process"); - let mut child_exit_status = None; - - for _ in 0..120 { - if health_check(1, 1).await { - info!("Prover started successfully"); - return; - } - - if child_exit_status.is_none() { - match child.try_wait() { - Ok(Some(status)) => { - tracing::warn!( - ?status, - "prover launcher exited before health check succeeded; continuing to poll for detached prover" - ); - child_exit_status = Some(status); - } - Ok(None) => {} - Err(error) => { - tracing::error!(?error, "failed to poll prover child process"); - } - } - } - - sleep(Duration::from_secs(1)).await; - } - - if let Some(status) = child_exit_status { - panic!( - "Failed to start prover, health check failed after launcher exited with status {status}." - ); + Command::new(&prover_path) + .arg("start-prover") + .spawn() + .unwrap_or_else(|error| panic!("Failed to start prover process: {error}")); + + if health_check(STARTUP_HEALTH_CHECK_RETRIES, 1).await { + info!("Prover started successfully"); + } else { + panic!("Failed to start prover, health check failed."); } - - panic!("Failed to start prover, health check failed."); } .await; diff --git a/sdk-libs/client/src/interface/instructions.rs b/sdk-libs/client/src/interface/instructions.rs index f6d754b9b1..e80e7b72c1 100644 --- a/sdk-libs/client/src/interface/instructions.rs +++ b/sdk-libs/client/src/interface/instructions.rs @@ -8,7 +8,7 @@ use light_account::{ CompressedAccountData, InitializeLightConfigParams, Pack, UpdateLightConfigParams, }; use light_sdk::instruction::{ - account_meta::CompressedAccountMetaNoLamportsNoAddress, PackedAccounts, + account_meta::CompressedAccountMetaNoLamportsNoAddress, PackedAccounts, PackedStateTreeInfo, SystemAccountMetaConfig, ValidityProof, }; use light_token::constants::{ @@ -247,11 +247,15 @@ where // Process PDAs first, then tokens, to match on-chain split_at(token_accounts_offset). for &i in pda_indices.iter().chain(token_indices.iter()) { let (acc, data) = &cold_accounts[i]; - let _queue_index = remaining_accounts.insert_or_get(acc.tree_info.queue); - let tree_info = tree_infos + let proof_tree_info = tree_infos .get(i) .copied() .ok_or("tree info index out of bounds")?; + let queue_index = remaining_accounts.insert_or_get(acc.tree_info.queue); + let tree_info = PackedStateTreeInfo { + queue_pubkey_index: queue_index, + ..proof_tree_info + }; let packed_data = data.pack(&mut remaining_accounts)?; typed_accounts.push(CompressedAccountData { diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index e799d3f29e..6e61c3bd98 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -472,8 +472,6 @@ impl Indexer for TestIndexer { let account_data = account.value.ok_or(IndexerError::AccountNotFound)?; state_merkle_tree_pubkeys.push(account_data.tree_info.tree); } - println!("state_merkle_tree_pubkeys {:?}", state_merkle_tree_pubkeys); - println!("hashes {:?}", hashes); let mut proof_inputs = vec![]; let mut indices_to_remove = Vec::new(); @@ -495,14 +493,7 @@ impl Indexer for TestIndexer { .output_queue_elements .iter() .find(|(hash, _)| hash == compressed_account); - println!("queue_element {:?}", queue_element); - if let Some((_, index)) = queue_element { - println!("index {:?}", index); - println!( - "accounts.output_queue_batch_size {:?}", - accounts.output_queue_batch_size - ); if accounts.output_queue_batch_size.is_some() && accounts.leaf_index_in_queue_range(*index as usize)? { @@ -513,12 +504,7 @@ impl Indexer for TestIndexer { hash: *compressed_account, root: [0u8; 32], root_index: RootIndex::new_none(), - leaf_index: accounts - .output_queue_elements - .iter() - .position(|(x, _)| x == compressed_account) - .unwrap() - as u64, + leaf_index: *index, tree_info: light_client::indexer::TreeInfo { cpi_context: Some(accounts.accounts.cpi_context), tree: accounts.accounts.merkle_tree, @@ -2085,6 +2071,106 @@ impl TestIndexer { } } +#[cfg(all(test, feature = "v2"))] +mod tests { + use super::*; + use light_compressed_account::compressed_account::CompressedAccount; + + fn queued_account( + owner: [u8; 32], + merkle_tree: Pubkey, + queue: Pubkey, + leaf_index: u32, + ) -> CompressedAccountWithMerkleContext { + CompressedAccountWithMerkleContext { + compressed_account: CompressedAccount { + owner: owner.into(), + lamports: 0, + address: None, + data: None, + }, + merkle_context: MerkleContext { + merkle_tree_pubkey: merkle_tree.to_bytes().into(), + queue_pubkey: queue.to_bytes().into(), + leaf_index, + prove_by_index: false, + tree_type: TreeType::StateV2, + }, + } + } + + #[tokio::test] + async fn get_validity_proof_preserves_sparse_queue_leaf_indices() { + let merkle_tree = Pubkey::new_unique(); + let queue = Pubkey::new_unique(); + let sparse_leaf_indices = [5_u32, 1, 0, 4]; + + let compressed_accounts: Vec<_> = sparse_leaf_indices + .iter() + .enumerate() + .map(|(i, &leaf_index)| { + queued_account([i as u8 + 1; 32], merkle_tree, queue, leaf_index) + }) + .collect(); + let hashes: Vec<_> = compressed_accounts + .iter() + .map(|account| account.hash().unwrap()) + .collect(); + + let output_queue_elements = hashes + .iter() + .zip(sparse_leaf_indices.iter()) + .map(|(hash, &leaf_index)| (*hash, leaf_index as u64)) + .collect(); + + let indexer = TestIndexer { + state_merkle_trees: vec![StateMerkleTreeBundle { + rollover_fee: 0, + network_fee: 0, + merkle_tree: Box::new(MerkleTree::::new_with_history( + DEFAULT_BATCH_STATE_TREE_HEIGHT, + 0, + 0, + DEFAULT_BATCH_ROOT_HISTORY_LEN, + )), + accounts: StateMerkleTreeAccounts { + merkle_tree, + nullifier_queue: queue, + cpi_context: Pubkey::new_unique(), + tree_type: TreeType::StateV2, + }, + tree_type: TreeType::StateV2, + output_queue_elements, + input_leaf_indices: vec![], + output_queue_batch_size: Some(500), + num_inserted_batches: 0, + }], + address_merkle_trees: vec![], + payer: Keypair::new(), + governance_authority: Keypair::new(), + group_pda: Pubkey::new_unique(), + compressed_accounts, + nullified_compressed_accounts: vec![], + token_compressed_accounts: vec![], + token_nullified_compressed_accounts: vec![], + events: vec![], + onchain_pubkey_index: HashMap::new(), + }; + + let response = Indexer::get_validity_proof(&indexer, hashes, vec![], None) + .await + .unwrap(); + let leaf_indices: Vec = response + .value + .accounts + .iter() + .map(|account| account.leaf_index) + .collect(); + + assert_eq!(leaf_indices, sparse_leaf_indices.map(u64::from)); + } +} + impl TestIndexer { async fn process_inclusion_proofs( &self, @@ -2346,7 +2432,16 @@ impl TestIndexer { new_addresses.unwrap().len() ))); } - let client = Client::new(); + let client = Client::builder() + .no_proxy() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(120)) + .build() + .map_err(|error| { + IndexerError::CustomError(format!( + "failed to build prover HTTP client: {error}" + )) + })?; let (account_proof_inputs, address_proof_inputs, json_payload) = match (compressed_accounts, new_addresses) { (Some(accounts), None) => { @@ -2471,6 +2566,7 @@ impl TestIndexer { }; let mut retries = 3; + let mut last_error = "Failed to get proof from server".to_string(); while retries > 0 { let response_result = client .post(format!("{}{}", SERVER_ADDRESS, PROVE_PATH)) @@ -2478,33 +2574,50 @@ impl TestIndexer { .body(json_payload.clone()) .send() .await; - if let Ok(response_result) = response_result { - if response_result.status().is_success() { - let body = response_result.text().await.unwrap(); - let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = - compress_proof(&proof_a, &proof_b, &proof_c); - return Ok(ValidityProofWithContext { - accounts: account_proof_inputs, - addresses: address_proof_inputs, - proof: CompressedProof { - a: proof_a, - b: proof_b, - c: proof_c, - } - .into(), - }); + match response_result { + Ok(response_result) => { + let status = response_result.status(); + let body = response_result.text().await.map_err(|error| { + IndexerError::CustomError(format!( + "failed to read prover response body: {error}" + )) + })?; + + if status.is_success() { + let proof_json = deserialize_gnark_proof_json(&body) + .map_err(|error| IndexerError::CustomError(error.to_string()))?; + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); + let (proof_a, proof_b, proof_c) = + compress_proof(&proof_a, &proof_b, &proof_c); + return Ok(ValidityProofWithContext { + accounts: account_proof_inputs, + addresses: address_proof_inputs, + proof: CompressedProof { + a: proof_a, + b: proof_b, + c: proof_c, + } + .into(), + }); + } + + let body_preview: String = body.chars().take(512).collect(); + last_error = format!( + "prover returned HTTP {status} for validity proof request: {body_preview}" + ); } - } else { - println!("Error: {:#?}", response_result); + Err(error) => { + last_error = + format!("failed to contact prover for validity proof: {error}"); + } + } + + retries -= 1; + if retries > 0 { tokio::time::sleep(Duration::from_secs(5)).await; - retries -= 1; } } - Err(IndexerError::CustomError( - "Failed to get proof from server".to_string(), - )) + Err(IndexerError::CustomError(last_error)) } } } diff --git a/sdk-libs/sdk-types/src/interface/account/token_seeds.rs b/sdk-libs/sdk-types/src/interface/account/token_seeds.rs index f22657590a..2bd0ee7bdc 100644 --- a/sdk-libs/sdk-types/src/interface/account/token_seeds.rs +++ b/sdk-libs/sdk-types/src/interface/account/token_seeds.rs @@ -265,7 +265,7 @@ where fn into_in_token_data( &self, tree_info: &PackedStateTreeInfo, - output_queue_index: u8, + _output_queue_index: u8, ) -> Result { Ok(MultiInputTokenDataWithContext { amount: self.token_data.amount, @@ -277,7 +277,7 @@ where root_index: tree_info.root_index, merkle_context: PackedMerkleContext { merkle_tree_pubkey_index: tree_info.merkle_tree_pubkey_index, - queue_pubkey_index: output_queue_index, + queue_pubkey_index: tree_info.queue_pubkey_index, leaf_index: tree_info.leaf_index, prove_by_index: tree_info.prove_by_index, }, diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs index 8819724f59..cc7aa4ba1f 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs @@ -142,22 +142,16 @@ where let pda_key = pda_account.key(); let address = derive_address(&pda_key, &ctx.light_config.address_space[0], ctx.program_id); - // 10. Build CompressedAccountInfo for CPI - // When PDA decompression is only the first phase of a later token Transfer2 execution, - // the stored input queue index must match that later execution basis, not the original - // packed proof basis. The mixed PDA+token flow uses `output_queue_index` for that. - let input_queue_index = if ctx.cpi_accounts.config().cpi_context { - output_queue_index - } else { - tree_info.queue_pubkey_index - }; - + // 10. Build CompressedAccountInfo for CPI. + // Input nullifiers must keep their original queue basis. The later system-program path + // groups nullifiers by queue index, so rewriting mixed PDA+token inputs onto a shared + // output queue drops whole tree/queue pairs from insertion. let input = InAccountInfo { data_hash: input_data_hash, lamports: 0, merkle_context: PackedMerkleContext { merkle_tree_pubkey_index: tree_info.merkle_tree_pubkey_index, - queue_pubkey_index: input_queue_index, + queue_pubkey_index: tree_info.queue_pubkey_index, leaf_index: tree_info.leaf_index, prove_by_index: tree_info.prove_by_index, }, diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs index 747c75ce32..585a828f03 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs @@ -472,13 +472,12 @@ async fn test_create_pdas_and_mint_auto() { .await .expect("create_load_instructions should succeed"); - println!("all_instructions.len() = {:?}", all_instructions); - - // Expected: 1 PDA+Token ix + 2 ATA ixs (1 create_ata + 1 decompress) + 1 mint ix = 4 + // Expected: 1 mint load, 1 grouped PDA/token load, and 2 ATA instructions + // (create ATA + Transfer2 decompression) = 4 total. assert_eq!( all_instructions.len(), - 6, - "Should have 6 instructions: 1 PDA, 1 Token, 2 create_ata, 1 decompress_ata, 1 mint" + 4, + "Should have 4 instructions: 1 mint, 1 grouped PDA/token load, 1 create_ata, 1 ATA Transfer2" ); // Capture rent sponsor balance before decompression From 588127c462396a66ace5d00dba338da73a948de1 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Mon, 16 Mar 2026 16:17:20 +0000 Subject: [PATCH 07/14] cleanup: harden prover startup polling --- prover/client/src/prover.rs | 99 +++++++++++++++++++++++++------------ 1 file changed, 67 insertions(+), 32 deletions(-) diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index e0f9d4060d..49b9f5aceb 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -1,7 +1,7 @@ use std::{ io::{Read, Write}, net::{TcpStream, ToSocketAddrs}, - process::Command, + process::{Child, Command}, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; @@ -39,15 +39,15 @@ pub(crate) fn build_http_client() -> Result } fn health_check_once(timeout: Duration) -> bool { - if prover_listener_present() { - return true; - } - let endpoint = SERVER_ADDRESS .strip_prefix("http://") .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) .unwrap_or(SERVER_ADDRESS); - let addr = match endpoint.to_socket_addrs().ok().and_then(|mut addrs| addrs.next()) { + let addr = match endpoint + .to_socket_addrs() + .ok() + .and_then(|mut addrs| addrs.next()) + { Some(addr) => addr, None => return false, }; @@ -64,8 +64,10 @@ fn health_check_once(timeout: Duration) -> bool { let _ = stream.set_write_timeout(Some(timeout)); let host = endpoint.split(':').next().unwrap_or("127.0.0.1"); - let request = - format!("GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", HEALTH_CHECK, host); + let request = format!( + "GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", + HEALTH_CHECK, host + ); if let Err(error) = stream.write_all(request.as_bytes()) { tracing::debug!(?error, "failed to write prover health request"); return health_check_once_with_curl(timeout); @@ -84,25 +86,6 @@ fn health_check_once(timeout: Duration) -> bool { && (has_http_ok_status(&response[..bytes_read]) || health_check_once_with_curl(timeout)) } -fn prover_listener_present() -> bool { - let endpoint = SERVER_ADDRESS - .strip_prefix("http://") - .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) - .unwrap_or(SERVER_ADDRESS); - let port = endpoint.rsplit(':').next().unwrap_or("3001"); - - match Command::new("lsof") - .args(["-nP", &format!("-iTCP:{port}"), "-sTCP:LISTEN"]) - .output() - { - Ok(output) => output.status.success() && !output.stdout.is_empty(), - Err(error) => { - tracing::debug!(?error, "failed to execute lsof prover listener check"); - false - } - } -} - fn health_check_once_with_curl(timeout: Duration) -> bool { let timeout_secs = timeout.as_secs().max(1).to_string(); let url = format!("{}{}", SERVER_ADDRESS, HEALTH_CHECK); @@ -121,6 +104,46 @@ fn health_check_once_with_curl(timeout: Duration) -> bool { } } +async fn wait_for_prover_health( + retries: usize, + timeout: Duration, + child: &mut Child, +) -> Result<(), String> { + for attempt in 0..retries { + if health_check_once(timeout) { + return Ok(()); + } + + match child.try_wait() { + Ok(Some(status)) => { + return Err(format!( + "prover process exited before health check succeeded with status {status}" + )); + } + Ok(None) => {} + Err(error) => { + return Err(format!("failed to poll prover process status: {error}")); + } + } + + if attempt + 1 < retries { + sleep(timeout).await; + } + } + + Err(format!( + "prover health check failed after {} attempts", + retries + )) +} + +fn monitor_prover_child(mut child: Child) { + std::thread::spawn(move || match child.wait() { + Ok(status) => tracing::debug!(?status, "prover launcher exited"), + Err(error) => tracing::warn!(?error, "failed to wait on prover launcher"), + }); +} + pub async fn spawn_prover() { if let Some(_project_root) = get_project_root() { let prover_path = { @@ -150,15 +173,27 @@ pub async fn spawn_prover() { } let spawn_result = async { - Command::new(&prover_path) + let mut child = Command::new(&prover_path) .arg("start-prover") .spawn() .unwrap_or_else(|error| panic!("Failed to start prover process: {error}")); - if health_check(STARTUP_HEALTH_CHECK_RETRIES, 1).await { - info!("Prover started successfully"); - } else { - panic!("Failed to start prover, health check failed."); + match wait_for_prover_health( + STARTUP_HEALTH_CHECK_RETRIES, + Duration::from_secs(1), + &mut child, + ) + .await + { + Ok(()) => { + monitor_prover_child(child); + info!("Prover started successfully"); + } + Err(error) => { + let _ = child.kill(); + let _ = child.wait(); + panic!("Failed to start prover: {error}"); + } } } .await; From 1e32b49a9278be4df6c63644c4d0633fc3156c99 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Mon, 16 Mar 2026 17:45:15 +0000 Subject: [PATCH 08/14] fix: harden test indexer proof parsing --- .../program-test/src/indexer/test_indexer.rs | 51 ++++++++++++++----- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 6e61c3bd98..9abe6a80cc 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -1,4 +1,10 @@ -use std::{collections::HashMap, fmt::Debug, time::Duration}; +use std::{ + any::Any, + collections::HashMap, + fmt::Debug, + panic::{catch_unwind, AssertUnwindSafe}, + time::Duration, +}; #[cfg(feature = "devenv")] use account_compression::{ @@ -95,6 +101,37 @@ use crate::accounts::{ }; use crate::indexer::TestIndexerExtensions; +fn panic_payload_message(payload: &(dyn Any + Send)) -> String { + if let Some(message) = payload.downcast_ref::() { + message.clone() + } else if let Some(message) = payload.downcast_ref::<&str>() { + (*message).to_string() + } else { + "non-string panic payload".to_string() + } +} + +fn build_compressed_proof(body: &str) -> Result { + let proof_json = deserialize_gnark_proof_json(body) + .map_err(|error| IndexerError::CustomError(error.to_string()))?; + let (proof_a, proof_b, proof_c) = catch_unwind(AssertUnwindSafe(|| { + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); + compress_proof(&proof_a, &proof_b, &proof_c) + })) + .map_err(|payload| { + IndexerError::CustomError(format!( + "failed to parse prover proof payload: {}", + panic_payload_message(payload.as_ref()) + )) + })?; + + Ok(CompressedProof { + a: proof_a, + b: proof_b, + c: proof_c, + }) +} + #[derive(Debug)] pub struct TestIndexer { pub state_merkle_trees: Vec, @@ -2584,20 +2621,10 @@ impl TestIndexer { })?; if status.is_success() { - let proof_json = deserialize_gnark_proof_json(&body) - .map_err(|error| IndexerError::CustomError(error.to_string()))?; - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = - compress_proof(&proof_a, &proof_b, &proof_c); return Ok(ValidityProofWithContext { accounts: account_proof_inputs, addresses: address_proof_inputs, - proof: CompressedProof { - a: proof_a, - b: proof_b, - c: proof_c, - } - .into(), + proof: build_compressed_proof(&body)?.into(), }); } From d4367678f7a9a4aacdd6c2516aee57f14910fe45 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Tue, 17 Mar 2026 08:32:34 +0000 Subject: [PATCH 09/14] fix: harden runtime safety edge cases --- forester/src/forester_status.rs | 32 +++-- forester/src/processor/v2/helpers.rs | 81 ++++++++---- .../actions/legacy/instructions/transfer2.rs | 4 +- program-tests/utils/src/e2e_test_env.rs | 7 +- prover/client/src/proof.rs | 121 +++++++++++++----- prover/client/src/proof_client.rs | 4 +- .../batch_address_append/proof_inputs.rs | 52 +++++--- .../proof_types/batch_update/proof_inputs.rs | 22 +++- sdk-libs/client/src/indexer/types/proof.rs | 42 ++++-- .../client/src/interface/initialize_config.rs | 12 +- sdk-libs/client/src/interface/instructions.rs | 4 +- .../client/src/interface/load_accounts.rs | 6 +- sdk-libs/client/src/interface/pack.rs | 7 +- sdk-libs/client/src/local_test_validator.rs | 37 +++--- sdk-libs/client/src/utils.rs | 27 ++-- .../program-test/src/indexer/test_indexer.rs | 35 +---- 16 files changed, 305 insertions(+), 188 deletions(-) diff --git a/forester/src/forester_status.rs b/forester/src/forester_status.rs index 80c4539075..038ac46ef4 100644 --- a/forester/src/forester_status.rs +++ b/forester/src/forester_status.rs @@ -671,19 +671,21 @@ fn parse_tree_status( let (queue_len, queue_cap) = queue_account .map(|acc| { - unsafe { parse_hash_set_from_bytes::(&acc.data) } - .ok() - .map(|hs| { + match unsafe { parse_hash_set_from_bytes::(&acc.data) } { + Ok(hs) => { let len = hs .iter() .filter(|(_, cell)| cell.sequence_number.is_none()) .count() as u64; let cap = hs.get_capacity() as u64; - (len, cap) - }) - .unwrap_or((0, 0)) + (Some(len), Some(cap)) + } + Err(error) => { + warn!(?error, "Failed to parse StateV1 queue hash set"); + (None, None) + } + } }) - .map(|(l, c)| (Some(l), Some(c))) .unwrap_or((None, None)); ( @@ -726,19 +728,21 @@ fn parse_tree_status( let (queue_len, queue_cap) = queue_account .map(|acc| { - unsafe { parse_hash_set_from_bytes::(&acc.data) } - .ok() - .map(|hs| { + match unsafe { parse_hash_set_from_bytes::(&acc.data) } { + Ok(hs) => { let len = hs .iter() .filter(|(_, cell)| cell.sequence_number.is_none()) .count() as u64; let cap = hs.get_capacity() as u64; - (len, cap) - }) - .unwrap_or((0, 0)) + (Some(len), Some(cap)) + } + Err(error) => { + warn!(?error, "Failed to parse AddressV1 queue hash set"); + (None, None) + } + } }) - .map(|(l, c)| (Some(l), Some(c))) .unwrap_or((None, None)); ( diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index a9fa2e290d..dd79ec9901 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -496,30 +496,57 @@ impl StreamingAddressQueue { if available < end || start >= end { return Ok(None); } - let actual_end = end; let data = lock_recover(&self.data, "streaming_address_queue.data"); - - for (name, len) in [ - ("addresses", data.addresses.len()), - ("low_element_values", data.low_element_values.len()), - ("low_element_next_values", data.low_element_next_values.len()), - ("low_element_indices", data.low_element_indices.len()), - ("low_element_next_indices", data.low_element_next_indices.len()), - ] { - if len < actual_end { - return Err(anyhow!( - "incomplete batch data: {} len {} < required end {}", - name, - len, - actual_end - )); - } - } - - let addresses = data.addresses[start..actual_end].to_vec(); + let Some(addresses) = data.addresses.get(start..end).map(|slice| slice.to_vec()) else { + return Ok(None); + }; if addresses.is_empty() { - return Err(anyhow!("Empty batch at start={}", start)); + return Ok(None); + } + let expected_len = addresses.len(); + let Some(low_element_values) = data + .low_element_values + .get(start..end) + .map(|slice| slice.to_vec()) + else { + return Ok(None); + }; + let Some(low_element_next_values) = data + .low_element_next_values + .get(start..end) + .map(|slice| slice.to_vec()) + else { + return Ok(None); + }; + let Some(low_element_indices) = data + .low_element_indices + .get(start..end) + .map(|slice| slice.to_vec()) + else { + return Ok(None); + }; + let Some(low_element_next_indices) = data + .low_element_next_indices + .get(start..end) + .map(|slice| slice.to_vec()) + else { + return Ok(None); + }; + if [ + low_element_values.len(), + low_element_next_values.len(), + low_element_indices.len(), + low_element_next_indices.len(), + ] + .iter() + .any(|&len| len != expected_len) + { + return Ok(None); } + let low_element_proofs = match data.reconstruct_proofs::(start..end) { + Ok(proofs) if proofs.len() == expected_len => proofs, + Ok(_) | Err(_) => return Ok(None), + }; let leaves_hashchain = match data.leaves_hash_chains.get(hashchain_idx).copied() { Some(hashchain) => hashchain, @@ -541,13 +568,11 @@ impl StreamingAddressQueue { }; Ok(Some(AddressBatchSnapshot { - low_element_values: data.low_element_values[start..actual_end].to_vec(), - low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), - low_element_indices: data.low_element_indices[start..actual_end].to_vec(), - low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - low_element_proofs: data.reconstruct_proofs::(start..actual_end).map_err( - |error| anyhow!("incomplete batch data: failed to reconstruct proofs: {error}"), - )?, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, addresses, leaves_hashchain, })) diff --git a/program-tests/utils/src/actions/legacy/instructions/transfer2.rs b/program-tests/utils/src/actions/legacy/instructions/transfer2.rs index 1ff92eeda9..00a55f3ad8 100644 --- a/program-tests/utils/src/actions/legacy/instructions/transfer2.rs +++ b/program-tests/utils/src/actions/legacy/instructions/transfer2.rs @@ -211,7 +211,9 @@ pub async fn create_generic_transfer2_instruction( let mut packed_tree_accounts = PackedAccounts::default(); // tree infos must be packed before packing the token input accounts - let packed_tree_infos = rpc_proof_result.pack_tree_infos(&mut packed_tree_accounts); + let packed_tree_infos = rpc_proof_result + .pack_tree_infos(&mut packed_tree_accounts) + .unwrap(); // We use a single shared output queue for all compress/compress-and-close operations to avoid ordering failures. let shared_output_queue = if packed_tree_infos.address_trees.is_empty() { diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index 6c9fdb5d5e..c2c415ec9e 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -764,8 +764,7 @@ where // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = - address_queue.tree_next_insertion_index as usize; + let start_index = address_queue.tree_next_insertion_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" @@ -836,9 +835,9 @@ where .map_err(|error| RpcError::CustomError(error.to_string())) .unwrap(); let (proof_a, proof_b, proof_c) = - proof_from_json_struct(proof_json); + proof_from_json_struct(proof_json).unwrap(); let (proof_a, proof_b, proof_c) = - compress_proof(&proof_a, &proof_b, &proof_c); + compress_proof(&proof_a, &proof_b, &proof_c).unwrap(); let instruction_data = InstructionDataBatchNullifyInputs { new_root: circuit_inputs_new_root, compressed_proof: CompressedProof { diff --git a/prover/client/src/proof.rs b/prover/client/src/proof.rs index c415a4d108..c88f2ea6cf 100644 --- a/prover/client/src/proof.rs +++ b/prover/client/src/proof.rs @@ -66,16 +66,28 @@ pub fn deserialize_gnark_proof_json(json_data: &str) -> serde_json::Result [u8; 32] { - let trimmed_str = hex_str.trim_start_matches("0x"); - let big_int = num_bigint::BigInt::from_str_radix(trimmed_str, 16).unwrap(); - let big_int_bytes = big_int.to_bytes_be().1; - if big_int_bytes.len() < 32 { +pub fn deserialize_hex_string_to_be_bytes(hex_str: &str) -> Result<[u8; 32], ProverClientError> { + let trimmed_str = hex_str + .strip_prefix("0x") + .or_else(|| hex_str.strip_prefix("0X")) + .unwrap_or(hex_str); + let big_uint = num_bigint::BigUint::from_str_radix(trimmed_str, 16).map_err(|error| { + ProverClientError::InvalidHexString(format!("{hex_str}: {error}")) + })?; + let big_uint_bytes = big_uint.to_bytes_be(); + if big_uint_bytes.len() > 32 { + return Err(ProverClientError::InvalidHexString(format!( + "{hex_str}: exceeds 32 bytes" + ))); + } + if big_uint_bytes.len() < 32 { let mut result = [0u8; 32]; - result[32 - big_int_bytes.len()..].copy_from_slice(&big_int_bytes); - result + result[32 - big_uint_bytes.len()..].copy_from_slice(&big_uint_bytes); + Ok(result) } else { - big_int_bytes.try_into().unwrap() + big_uint_bytes.try_into().map_err(|_| { + ProverClientError::InvalidHexString(format!("{hex_str}: invalid 32-byte encoding")) + }) } } @@ -83,47 +95,92 @@ pub fn compress_proof( proof_a: &[u8; 64], proof_b: &[u8; 128], proof_c: &[u8; 64], -) -> ([u8; 32], [u8; 64], [u8; 32]) { - let proof_a = alt_bn128_g1_compress(proof_a).unwrap(); - let proof_b = alt_bn128_g2_compress(proof_b).unwrap(); - let proof_c = alt_bn128_g1_compress(proof_c).unwrap(); - (proof_a, proof_b, proof_c) +) -> Result<([u8; 32], [u8; 64], [u8; 32]), ProverClientError> { + let proof_a = alt_bn128_g1_compress(proof_a)?; + let proof_b = alt_bn128_g2_compress(proof_b)?; + let proof_c = alt_bn128_g1_compress(proof_c)?; + Ok((proof_a, proof_b, proof_c)) } -pub fn proof_from_json_struct(json: GnarkProofJson) -> ([u8; 64], [u8; 128], [u8; 64]) { - let proof_a_x = deserialize_hex_string_to_be_bytes(&json.ar[0]); - let proof_a_y = deserialize_hex_string_to_be_bytes(&json.ar[1]); - let proof_a: [u8; 64] = [proof_a_x, proof_a_y].concat().try_into().unwrap(); - let proof_a = negate_g1(&proof_a); - let proof_b_x_0 = deserialize_hex_string_to_be_bytes(&json.bs[0][0]); - let proof_b_x_1 = deserialize_hex_string_to_be_bytes(&json.bs[0][1]); - let proof_b_y_0 = deserialize_hex_string_to_be_bytes(&json.bs[1][0]); - let proof_b_y_1 = deserialize_hex_string_to_be_bytes(&json.bs[1][1]); +pub fn proof_from_json_struct( + json: GnarkProofJson, +) -> Result<([u8; 64], [u8; 128], [u8; 64]), ProverClientError> { + let proof_a_x = deserialize_hex_string_to_be_bytes(json.ar.first().ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof A x coordinate".to_string()) + })?)?; + let proof_a_y = deserialize_hex_string_to_be_bytes(json.ar.get(1).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof A y coordinate".to_string()) + })?)?; + let proof_a: [u8; 64] = [proof_a_x, proof_a_y] + .concat() + .try_into() + .map_err(|_| ProverClientError::InvalidProofData("invalid proof A length".to_string()))?; + let proof_a = negate_g1(&proof_a)?; + let proof_b_x_0 = deserialize_hex_string_to_be_bytes( + json.bs + .first() + .and_then(|row| row.first()) + .ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B x0 coordinate".to_string()) + })?, + )?; + let proof_b_x_1 = deserialize_hex_string_to_be_bytes( + json.bs + .first() + .and_then(|row| row.get(1)) + .ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B x1 coordinate".to_string()) + })?, + )?; + let proof_b_y_0 = deserialize_hex_string_to_be_bytes( + json.bs + .get(1) + .and_then(|row| row.first()) + .ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B y0 coordinate".to_string()) + })?, + )?; + let proof_b_y_1 = deserialize_hex_string_to_be_bytes( + json.bs + .get(1) + .and_then(|row| row.get(1)) + .ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B y1 coordinate".to_string()) + })?, + )?; let proof_b: [u8; 128] = [proof_b_x_0, proof_b_x_1, proof_b_y_0, proof_b_y_1] .concat() .try_into() - .unwrap(); + .map_err(|_| ProverClientError::InvalidProofData("invalid proof B length".to_string()))?; - let proof_c_x = deserialize_hex_string_to_be_bytes(&json.krs[0]); - let proof_c_y = deserialize_hex_string_to_be_bytes(&json.krs[1]); - let proof_c: [u8; 64] = [proof_c_x, proof_c_y].concat().try_into().unwrap(); - (proof_a, proof_b, proof_c) + let proof_c_x = deserialize_hex_string_to_be_bytes(json.krs.first().ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof C x coordinate".to_string()) + })?)?; + let proof_c_y = deserialize_hex_string_to_be_bytes(json.krs.get(1).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof C y coordinate".to_string()) + })?)?; + let proof_c: [u8; 64] = [proof_c_x, proof_c_y] + .concat() + .try_into() + .map_err(|_| ProverClientError::InvalidProofData("invalid proof C length".to_string()))?; + Ok((proof_a, proof_b, proof_c)) } -pub fn negate_g1(g1_be: &[u8; 64]) -> [u8; 64] { +pub fn negate_g1(g1_be: &[u8; 64]) -> Result<[u8; 64], ProverClientError> { let g1_le = convert_endianness::<32, 64>(g1_be); - let g1: G1 = G1::deserialize_with_mode(g1_le.as_slice(), Compress::No, Validate::No).unwrap(); + let g1: G1 = G1::deserialize_with_mode(g1_le.as_slice(), Compress::No, Validate::Yes) + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; let g1_neg = g1.neg(); let mut g1_neg_be = [0u8; 64]; g1_neg .x .serialize_with_mode(&mut g1_neg_be[..32], Compress::No) - .unwrap(); + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; g1_neg .y .serialize_with_mode(&mut g1_neg_be[32..], Compress::No) - .unwrap(); + .map_err(|error| ProverClientError::InvalidProofData(error.to_string()))?; let g1_neg_be: [u8; 64] = convert_endianness::<32, 64>(&g1_neg_be); - g1_neg_be + Ok(g1_neg_be) } diff --git a/prover/client/src/proof_client.rs b/prover/client/src/proof_client.rs index 820c9fe07d..b3dacf295f 100644 --- a/prover/client/src/proof_client.rs +++ b/prover/client/src/proof_client.rs @@ -654,8 +654,8 @@ impl ProofClient { ProverClientError::ProverServerError(format!("Failed to deserialize proof JSON: {}", e)) })?; - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json)?; + let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c)?; Ok(ProofResult { proof: ProofCompressed { diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index 231fa17ac5..f11b6fb0cb 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -217,6 +217,10 @@ pub fn get_batch_address_append_circuit_inputs( } let new_element_values = &new_element_values[..batch_len]; + let mut staged_changelog = changelog.clone(); + let mut staged_indexed_changelog = indexed_changelog.clone(); + let mut staged_sparse_merkle_tree = sparse_merkle_tree.clone(); + let initial_changelog_len = staged_changelog.len(); let mut new_root = [0u8; 32]; let mut low_element_circuit_merkle_proofs = Vec::with_capacity(batch_len); let mut new_element_circuit_merkle_proofs = Vec::with_capacity(batch_len); @@ -253,9 +257,9 @@ pub fn get_batch_address_append_circuit_inputs( next_index ); - let mut patcher = ChangelogProofPatcher::new::(changelog); + let mut patcher = ChangelogProofPatcher::new::(&staged_changelog); - let is_first_batch = indexed_changelog.is_empty(); + let is_first_batch = staged_indexed_changelog.is_empty(); let mut expected_root_for_low = current_root; for i in 0..batch_len { @@ -291,7 +295,7 @@ pub fn get_batch_address_append_circuit_inputs( patch_indexed_changelogs( 0, &mut changelog_index, - indexed_changelog, + &mut staged_indexed_changelog, &mut low_element, &mut new_element, &mut low_element_next_value, @@ -383,7 +387,7 @@ pub fn get_batch_address_append_circuit_inputs( new_low_element.index, )?; - patcher.push_changelog_entry::(changelog, changelog_entry); + patcher.push_changelog_entry::(&mut staged_changelog, changelog_entry); low_element_circuit_merkle_proofs.push( merkle_proof .iter() @@ -396,10 +400,10 @@ pub fn get_batch_address_append_circuit_inputs( let low_element_changelog_entry = IndexedChangelogEntry { element: new_low_element_raw, proof: low_element_changelog_proof, - changelog_index: indexed_changelog.len(), //change_log_index, + changelog_index: staged_indexed_changelog.len(), //change_log_index, }; - indexed_changelog.push(low_element_changelog_entry); + staged_indexed_changelog.push(low_element_changelog_entry); { let new_element_next_value = low_element_next_value; @@ -409,10 +413,10 @@ pub fn get_batch_address_append_circuit_inputs( ProverClientError::GenericError(format!("Failed to hash new element: {}", e)) })?; - let sparse_root_before = sparse_merkle_tree.root(); - let sparse_next_idx_before = sparse_merkle_tree.get_next_index(); + let sparse_root_before = staged_sparse_merkle_tree.root(); + let sparse_next_idx_before = staged_sparse_merkle_tree.get_next_index(); - let mut merkle_proof_array = sparse_merkle_tree.append(new_element_leaf_hash); + let mut merkle_proof_array = staged_sparse_merkle_tree.append(new_element_leaf_hash); let current_index = next_index + i; @@ -424,7 +428,7 @@ pub fn get_batch_address_append_circuit_inputs( current_index, )?; - if i == 0 && changelog.len() == 1 { + if i == 0 && staged_changelog.len() == initial_changelog_len + 1 { if sparse_next_idx_before != current_index { return Err(ProverClientError::GenericError(format!( "sparse index mismatch: sparse tree next_index={} but expected current_index={}", @@ -483,7 +487,7 @@ pub fn get_batch_address_append_circuit_inputs( new_root = updated_root; - patcher.push_changelog_entry::(changelog, changelog_entry); + patcher.push_changelog_entry::(&mut staged_changelog, changelog_entry); new_element_circuit_merkle_proofs.push( merkle_proof_array .iter() @@ -501,9 +505,9 @@ pub fn get_batch_address_append_circuit_inputs( let new_element_changelog_entry = IndexedChangelogEntry { element: new_element_raw, proof: merkle_proof_array, - changelog_index: indexed_changelog.len(), + changelog_index: staged_indexed_changelog.len(), }; - indexed_changelog.push(new_element_changelog_entry); + staged_indexed_changelog.push(new_element_changelog_entry); } } @@ -539,18 +543,18 @@ pub fn get_batch_address_append_circuit_inputs( patcher.hits, patcher.misses, patcher.overwrites, - changelog.len(), - indexed_changelog.len() + staged_changelog.len(), + staged_indexed_changelog.len() ); - if patcher.hits == 0 && !changelog.is_empty() { + if patcher.hits == 0 && !staged_changelog.is_empty() { tracing::warn!( "Address proof patcher had 0 cache hits despite non-empty changelog (changelog_len={}, indexed_changelog_len={})", - changelog.len(), - indexed_changelog.len() + staged_changelog.len(), + staged_indexed_changelog.len() ); } - Ok(BatchAddressAppendInputs { + let inputs = BatchAddressAppendInputs { batch_size: patched_low_element_values.len(), hashchain_hash: BigUint::from_bytes_be(&leaves_hashchain), low_element_values: patched_low_element_values @@ -570,7 +574,7 @@ pub fn get_batch_address_append_circuit_inputs( .map(|v| BigUint::from_bytes_be(v)) .collect(), low_element_proofs: low_element_circuit_merkle_proofs, - new_element_values: new_element_values[0..] + new_element_values: new_element_values .iter() .map(|v| BigUint::from_bytes_be(v)) .collect(), @@ -580,5 +584,11 @@ pub fn get_batch_address_append_circuit_inputs( public_input_hash: BigUint::from_bytes_be(&public_input_hash), start_index: next_index, tree_height: HEIGHT, - }) + }; + + *changelog = staged_changelog; + *indexed_changelog = staged_indexed_changelog; + *sparse_merkle_tree = staged_sparse_merkle_tree; + + Ok(inputs) } diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 2ada02b92b..f5467184aa 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -31,8 +31,12 @@ pub struct BatchUpdateCircuitInputs { } impl BatchUpdateCircuitInputs { - pub fn public_inputs_arr(&self) -> [u8; 32] { - bigint_to_u8_32(&self.public_input_hash).unwrap() + pub fn public_inputs_arr(&self) -> Result<[u8; 32], ProverClientError> { + bigint_to_u8_32(&self.public_input_hash).map_err(|error| { + ProverClientError::GenericError(format!( + "failed to serialize batch update public input: {error}" + )) + }) } pub fn new( @@ -112,9 +116,17 @@ impl BatchUpdateCircuitInputs { pub struct BatchUpdateInputs<'a>(pub &'a [BatchUpdateCircuitInputs]); impl BatchUpdateInputs<'_> { - pub fn public_inputs(&self) -> Vec<[u8; 32]> { - // Concatenate all public inputs into a single flat vector - vec![self.0[0].public_inputs_arr()] + pub fn public_inputs(&self) -> Result, ProverClientError> { + if self.0.is_empty() { + return Err(ProverClientError::GenericError( + "batch update inputs cannot be empty".to_string(), + )); + } + + self.0 + .iter() + .map(BatchUpdateCircuitInputs::public_inputs_arr) + .collect() } } diff --git a/sdk-libs/client/src/indexer/types/proof.rs b/sdk-libs/client/src/indexer/types/proof.rs index 0b45e00986..1c858fd74d 100644 --- a/sdk-libs/client/src/indexer/types/proof.rs +++ b/sdk-libs/client/src/indexer/types/proof.rs @@ -189,7 +189,10 @@ pub struct PackedTreeInfos { } impl ValidityProofWithContext { - pub fn pack_tree_infos(&self, packed_accounts: &mut PackedAccounts) -> PackedTreeInfos { + pub fn pack_tree_infos( + &self, + packed_accounts: &mut PackedAccounts, + ) -> Result { let mut packed_tree_infos = Vec::new(); let mut address_trees = Vec::new(); let mut output_tree_index = None; @@ -211,19 +214,28 @@ impl ValidityProofWithContext { if let Some(next) = account.tree_info.next_tree_info { // SAFETY: account will always have a state Merkle tree context. // pack_output_tree_index only panics on an address Merkle tree context. - let index = next.pack_output_tree_index(packed_accounts).unwrap(); - if output_tree_index.is_none() { - output_tree_index = Some(index); + let index = next.pack_output_tree_index(packed_accounts)?; + match output_tree_index { + Some(existing) if existing != index => { + return Err(IndexerError::InvalidParameters(format!( + "mixed output tree indices in state proof: {existing} != {index}" + ))); + } + Some(_) => {} + None => output_tree_index = Some(index), } } else { // SAFETY: account will always have a state Merkle tree context. // pack_output_tree_index only panics on an address Merkle tree context. - let index = account - .tree_info - .pack_output_tree_index(packed_accounts) - .unwrap(); - if output_tree_index.is_none() { - output_tree_index = Some(index); + let index = account.tree_info.pack_output_tree_index(packed_accounts)?; + match output_tree_index { + Some(existing) if existing != index => { + return Err(IndexerError::InvalidParameters(format!( + "mixed output tree indices in state proof: {existing} != {index}" + ))); + } + Some(_) => {} + None => output_tree_index = Some(index), } } } @@ -244,13 +256,17 @@ impl ValidityProofWithContext { } else { Some(PackedStateTreeInfos { packed_tree_infos, - output_tree_index: output_tree_index.unwrap(), + output_tree_index: output_tree_index.ok_or_else(|| { + IndexerError::InvalidParameters( + "missing output tree index for non-empty state proof".to_string(), + ) + })?, }) }; - PackedTreeInfos { + Ok(PackedTreeInfos { state_trees: packed_tree_infos, address_trees, - } + }) } pub fn from_api_model( diff --git a/sdk-libs/client/src/interface/initialize_config.rs b/sdk-libs/client/src/interface/initialize_config.rs index 7b5919cdb1..9fbeacfe89 100644 --- a/sdk-libs/client/src/interface/initialize_config.rs +++ b/sdk-libs/client/src/interface/initialize_config.rs @@ -7,6 +7,8 @@ use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSeria use solana_instruction::{AccountMeta, Instruction}; use solana_pubkey::Pubkey; +use crate::interface::instructions::INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR; + /// Default address tree v2 pubkey. pub const ADDRESS_TREE_V2: Pubkey = solana_pubkey::pubkey!("amt2kaJA14v3urZbZvnc5v2np8jqvc4Z8zDep5wbtzx"); @@ -115,16 +117,14 @@ impl InitializeRentFreeConfig { address_space: self.address_space, }; - // Anchor discriminator for "initialize_compression_config" - // SHA256("global:initialize_compression_config")[..8] - const DISCRIMINATOR: [u8; 8] = [133, 228, 12, 169, 56, 76, 222, 61]; - let serialized_data = instruction_data .try_to_vec() .expect("Failed to serialize instruction data"); - let mut data = Vec::with_capacity(DISCRIMINATOR.len() + serialized_data.len()); - data.extend_from_slice(&DISCRIMINATOR); + let mut data = Vec::with_capacity( + INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR.len() + serialized_data.len(), + ); + data.extend_from_slice(&INITIALIZE_COMPRESSION_CONFIG_DISCRIMINATOR); data.extend_from_slice(&serialized_data); let instruction = Instruction { diff --git a/sdk-libs/client/src/interface/instructions.rs b/sdk-libs/client/src/interface/instructions.rs index e80e7b72c1..41c06e637f 100644 --- a/sdk-libs/client/src/interface/instructions.rs +++ b/sdk-libs/client/src/interface/instructions.rs @@ -234,7 +234,7 @@ where let output_queue = get_output_queue(&cold_accounts[0].0.tree_info); let output_state_tree_index = remaining_accounts.insert_or_get(output_queue); - let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts); + let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts)?; let tree_infos = &packed_tree_infos .state_trees .as_ref() @@ -313,7 +313,7 @@ pub fn build_compress_accounts_idempotent( let output_queue = get_output_queue(&proof.accounts[0].tree_info); let output_state_tree_index = remaining_accounts.insert_or_get(output_queue); - let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts); + let packed_tree_infos = proof.pack_tree_infos(&mut remaining_accounts)?; let tree_infos = packed_tree_infos .state_trees .as_ref() diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index 0d564734a2..4fccd73810 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -220,7 +220,7 @@ fn group_pda_specs<'a, V>( specs: &[&'a PdaSpec], max_per_group: usize, ) -> Vec>> { - assert!(max_per_group > 0, "max_per_group must be non-zero"); + debug_assert!(max_per_group > 0, "max_per_group must be non-zero"); if specs.is_empty() { return Vec::new(); } @@ -424,7 +424,9 @@ fn build_transfer2( fee_payer: Pubkey, ) -> Result { let mut packed = PackedAccounts::default(); - let packed_trees = proof.pack_tree_infos(&mut packed); + let packed_trees = proof + .pack_tree_infos(&mut packed) + .map_err(|error| LoadAccountsError::BuildInstruction(error.to_string()))?; let tree_infos = packed_trees .state_trees .as_ref() diff --git a/sdk-libs/client/src/interface/pack.rs b/sdk-libs/client/src/interface/pack.rs index 804a48751d..97505adabe 100644 --- a/sdk-libs/client/src/interface/pack.rs +++ b/sdk-libs/client/src/interface/pack.rs @@ -12,6 +12,9 @@ use crate::indexer::{TreeInfo, ValidityProofWithContext}; pub enum PackError { #[error("Failed to add system accounts: {0}")] SystemAccounts(#[from] light_sdk::error::LightSdkError), + + #[error("Failed to pack tree infos: {0}")] + Indexer(#[from] crate::indexer::IndexerError), } /// Packed state tree infos from validity proof. @@ -87,7 +90,7 @@ fn pack_proof_internal( // For mint creation: pack address tree first (index 1), then state tree. let (client_packed_tree_infos, state_tree_index) = if include_state_tree { // Pack tree infos first to ensure address tree is at index 1 - let tree_infos = proof.pack_tree_infos(&mut packed); + let tree_infos = proof.pack_tree_infos(&mut packed)?; // Then add state tree (will be after address tree) let state_tree = output_tree @@ -99,7 +102,7 @@ fn pack_proof_internal( (tree_infos, Some(state_idx)) } else { - let tree_infos = proof.pack_tree_infos(&mut packed); + let tree_infos = proof.pack_tree_infos(&mut packed)?; (tree_infos, None) }; let (remaining_accounts, system_offset, _) = packed.to_account_metas(); diff --git a/sdk-libs/client/src/local_test_validator.rs b/sdk-libs/client/src/local_test_validator.rs index 36ed7c04b3..4370f4911a 100644 --- a/sdk-libs/client/src/local_test_validator.rs +++ b/sdk-libs/client/src/local_test_validator.rs @@ -1,6 +1,7 @@ -use std::process::{Command, Stdio}; +use std::process::Stdio; use light_prover_client::helpers::get_project_root; +use tokio::process::Command; /// Configuration for an upgradeable program to deploy to the validator. #[derive(Debug, Clone)] @@ -57,25 +58,25 @@ impl Default for LightValidatorConfig { pub async fn spawn_validator(config: LightValidatorConfig) { if let Some(project_root) = get_project_root() { - let path = "cli/test_bin/run test-validator"; - let mut path = format!("{}/{}", project_root.trim(), path); + let command = "cli/test_bin/run test-validator"; + let mut command = format!("{}/{}", project_root.trim(), command); if !config.enable_indexer { - path.push_str(" --skip-indexer"); + command.push_str(" --skip-indexer"); } if let Some(limit_ledger_size) = config.limit_ledger_size { - path.push_str(&format!(" --limit-ledger-size {}", limit_ledger_size)); + command.push_str(&format!(" --limit-ledger-size {}", limit_ledger_size)); } for sbf_program in config.sbf_programs.iter() { - path.push_str(&format!( + command.push_str(&format!( " --sbf-program {} {}", sbf_program.0, sbf_program.1 )); } for upgradeable_program in config.upgradeable_programs.iter() { - path.push_str(&format!( + command.push_str(&format!( " --upgradeable-program {} {} {}", upgradeable_program.program_id, upgradeable_program.program_path, @@ -84,18 +85,18 @@ pub async fn spawn_validator(config: LightValidatorConfig) { } if !config.enable_prover { - path.push_str(" --skip-prover"); + command.push_str(" --skip-prover"); } if config.use_surfpool { - path.push_str(" --use-surfpool"); + command.push_str(" --use-surfpool"); } for arg in config.validator_args.iter() { - path.push_str(&format!(" {}", arg)); + command.push_str(&format!(" {}", arg)); } - println!("Starting validator with command: {}", path); + println!("Starting validator with command: {}", command); if config.use_surfpool { // The CLI starts surfpool, prover, and photon, then exits once all @@ -103,24 +104,28 @@ pub async fn spawn_validator(config: LightValidatorConfig) { // is up before the test proceeds. let mut child = Command::new("sh") .arg("-c") - .arg(path) + .arg(command) .stdin(Stdio::null()) .stdout(Stdio::inherit()) .stderr(Stdio::inherit()) .spawn() .expect("Failed to start server process"); - let status = child.wait().expect("Failed to wait for CLI process"); + let status = child + .wait() + .await + .expect("Failed to wait for CLI process"); assert!(status.success(), "CLI exited with error: {}", status); } else { - let child = Command::new("sh") + let _child = Command::new("sh") .arg("-c") - .arg(path) + .arg(command) .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()) .spawn() .expect("Failed to start server process"); - std::mem::drop(child); + // Intentionally detaching the spawned child; the caller only waits + // for the validator services to become available. tokio::time::sleep(tokio::time::Duration::from_secs(config.wait_time)).await; } } diff --git a/sdk-libs/client/src/utils.rs b/sdk-libs/client/src/utils.rs index b8f2e05ecb..0055f8dbea 100644 --- a/sdk-libs/client/src/utils.rs +++ b/sdk-libs/client/src/utils.rs @@ -15,8 +15,11 @@ pub fn find_light_bin() -> Option { if !output.status.success() { return None; } - // Convert the output into a string (removing any trailing newline) - let light_path = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let light_path = std::str::from_utf8(&output.stdout) + .ok()? + .trim_end_matches("\r\n") + .trim_end_matches('\n') + .to_string(); // Get the parent directory of the 'light' binary let mut light_bin_path = PathBuf::from(light_path); light_bin_path.pop(); // Remove the 'light' binary itself @@ -30,16 +33,16 @@ pub fn find_light_bin() -> Option { #[cfg(feature = "devenv")] { println!("Use only in light protocol monorepo. Using 'git rev-parse --show-toplevel' to find the location of 'light' binary"); - let light_protocol_toplevel = String::from_utf8_lossy( - &std::process::Command::new("git") - .arg("rev-parse") - .arg("--show-toplevel") - .output() - .expect("Failed to get top-level directory") - .stdout, - ) - .trim() - .to_string(); + let output = std::process::Command::new("git") + .arg("rev-parse") + .arg("--show-toplevel") + .output() + .expect("Failed to get top-level directory"); + let light_protocol_toplevel = std::str::from_utf8(&output.stdout) + .ok()? + .trim_end_matches("\r\n") + .trim_end_matches('\n') + .to_string(); let light_path = PathBuf::from(format!("{}/target/deploy/", light_protocol_toplevel)); Some(light_path) } diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 9abe6a80cc..7618e045f3 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -1,10 +1,4 @@ -use std::{ - any::Any, - collections::HashMap, - fmt::Debug, - panic::{catch_unwind, AssertUnwindSafe}, - time::Duration, -}; +use std::{collections::HashMap, fmt::Debug, time::Duration}; #[cfg(feature = "devenv")] use account_compression::{ @@ -101,29 +95,13 @@ use crate::accounts::{ }; use crate::indexer::TestIndexerExtensions; -fn panic_payload_message(payload: &(dyn Any + Send)) -> String { - if let Some(message) = payload.downcast_ref::() { - message.clone() - } else if let Some(message) = payload.downcast_ref::<&str>() { - (*message).to_string() - } else { - "non-string panic payload".to_string() - } -} - fn build_compressed_proof(body: &str) -> Result { let proof_json = deserialize_gnark_proof_json(body) .map_err(|error| IndexerError::CustomError(error.to_string()))?; - let (proof_a, proof_b, proof_c) = catch_unwind(AssertUnwindSafe(|| { - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - compress_proof(&proof_a, &proof_b, &proof_c) - })) - .map_err(|payload| { - IndexerError::CustomError(format!( - "failed to parse prover proof payload: {}", - panic_payload_message(payload.as_ref()) - )) - })?; + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json) + .map_err(|error| IndexerError::CustomError(error.to_string()))?; + let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c) + .map_err(|error| IndexerError::CustomError(error.to_string()))?; Ok(CompressedProof { a: proof_a, @@ -2110,9 +2088,10 @@ impl TestIndexer { #[cfg(all(test, feature = "v2"))] mod tests { - use super::*; use light_compressed_account::compressed_account::CompressedAccount; + use super::*; + fn queued_account( owner: [u8; 32], merkle_tree: Pubkey, From f333ac1f2f3295ee266f1858226868bae1687857 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Tue, 17 Mar 2026 08:34:12 +0000 Subject: [PATCH 10/14] fix: harden runtime safety fallout --- forester/src/forester_status.rs | 16 ++--- forester/src/processor/v2/proof_worker.rs | 4 +- .../utils/src/mock_batched_forester.rs | 4 +- prover/client/src/proof.rs | 51 ++++++-------- prover/client/tests/batch_address_append.rs | 67 +++++++++---------- sdk-libs/client/src/local_test_validator.rs | 5 +- .../sdk-anchor-test/tests/read_only.rs | 12 +++- .../programs/sdk-anchor-test/tests/test.rs | 8 ++- sdk-tests/sdk-native-test/tests/test.rs | 6 +- sdk-tests/sdk-pinocchio-v1-test/tests/test.rs | 6 +- 10 files changed, 93 insertions(+), 86 deletions(-) diff --git a/forester/src/forester_status.rs b/forester/src/forester_status.rs index 038ac46ef4..d8f958134d 100644 --- a/forester/src/forester_status.rs +++ b/forester/src/forester_status.rs @@ -670,8 +670,8 @@ fn parse_tree_status( let fullness = next_index as f64 / capacity as f64 * 100.0; let (queue_len, queue_cap) = queue_account - .map(|acc| { - match unsafe { parse_hash_set_from_bytes::(&acc.data) } { + .map( + |acc| match unsafe { parse_hash_set_from_bytes::(&acc.data) } { Ok(hs) => { let len = hs .iter() @@ -684,8 +684,8 @@ fn parse_tree_status( warn!(?error, "Failed to parse StateV1 queue hash set"); (None, None) } - } - }) + }, + ) .unwrap_or((None, None)); ( @@ -727,8 +727,8 @@ fn parse_tree_status( let fullness = next_index as f64 / capacity as f64 * 100.0; let (queue_len, queue_cap) = queue_account - .map(|acc| { - match unsafe { parse_hash_set_from_bytes::(&acc.data) } { + .map( + |acc| match unsafe { parse_hash_set_from_bytes::(&acc.data) } { Ok(hs) => { let len = hs .iter() @@ -741,8 +741,8 @@ fn parse_tree_status( warn!(?error, "Failed to parse AddressV1 queue hash set"); (None, None) } - } - }) + }, + ) .unwrap_or((None, None)); ( diff --git a/forester/src/processor/v2/proof_worker.rs b/forester/src/processor/v2/proof_worker.rs index ded9fcedc8..603fa3f19b 100644 --- a/forester/src/processor/v2/proof_worker.rs +++ b/forester/src/processor/v2/proof_worker.rs @@ -164,7 +164,9 @@ impl ProofClients { } } -pub fn spawn_proof_workers(config: &ProverConfig) -> crate::Result> { +pub fn spawn_proof_workers( + config: &ProverConfig, +) -> crate::Result> { let (job_tx, job_rx) = async_channel::bounded::(256); let clients = Arc::new(ProofClients::new(config)?); tokio::spawn(async move { run_proof_pipeline(job_rx, clients).await }); diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 8ea51b7169..f3ad76cdbe 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -209,9 +209,7 @@ impl MockBatchedForester { &[], )?; let proof_client = ProofClient::local()?; - let proof_result = proof_client - .generate_batch_update_proof(inputs) - .await?; + let proof_result = proof_client.generate_batch_update_proof(inputs).await?; let new_root = self.merkle_tree.root(); let proof = CompressedProof { a: proof_result.0.proof.a, diff --git a/prover/client/src/proof.rs b/prover/client/src/proof.rs index c88f2ea6cf..c5b5847815 100644 --- a/prover/client/src/proof.rs +++ b/prover/client/src/proof.rs @@ -12,6 +12,9 @@ use solana_bn254::compression::prelude::{ convert_endianness, }; +pub type CompressedProofBytes = ([u8; 32], [u8; 64], [u8; 32]); +pub type UncompressedProofBytes = ([u8; 64], [u8; 128], [u8; 64]); + #[derive(Debug, Clone, Copy)] pub struct ProofCompressed { pub a: [u8; 32], @@ -71,9 +74,8 @@ pub fn deserialize_hex_string_to_be_bytes(hex_str: &str) -> Result<[u8; 32], Pro .strip_prefix("0x") .or_else(|| hex_str.strip_prefix("0X")) .unwrap_or(hex_str); - let big_uint = num_bigint::BigUint::from_str_radix(trimmed_str, 16).map_err(|error| { - ProverClientError::InvalidHexString(format!("{hex_str}: {error}")) - })?; + let big_uint = num_bigint::BigUint::from_str_radix(trimmed_str, 16) + .map_err(|error| ProverClientError::InvalidHexString(format!("{hex_str}: {error}")))?; let big_uint_bytes = big_uint.to_bytes_be(); if big_uint_bytes.len() > 32 { return Err(ProverClientError::InvalidHexString(format!( @@ -95,7 +97,7 @@ pub fn compress_proof( proof_a: &[u8; 64], proof_b: &[u8; 128], proof_c: &[u8; 64], -) -> Result<([u8; 32], [u8; 64], [u8; 32]), ProverClientError> { +) -> Result { let proof_a = alt_bn128_g1_compress(proof_a)?; let proof_b = alt_bn128_g2_compress(proof_b)?; let proof_c = alt_bn128_g1_compress(proof_c)?; @@ -104,7 +106,7 @@ pub fn compress_proof( pub fn proof_from_json_struct( json: GnarkProofJson, -) -> Result<([u8; 64], [u8; 128], [u8; 64]), ProverClientError> { +) -> Result { let proof_a_x = deserialize_hex_string_to_be_bytes(json.ar.first().ok_or_else(|| { ProverClientError::InvalidProofData("missing proof A x coordinate".to_string()) })?)?; @@ -117,37 +119,24 @@ pub fn proof_from_json_struct( .map_err(|_| ProverClientError::InvalidProofData("invalid proof A length".to_string()))?; let proof_a = negate_g1(&proof_a)?; let proof_b_x_0 = deserialize_hex_string_to_be_bytes( - json.bs - .first() - .and_then(|row| row.first()) - .ok_or_else(|| { - ProverClientError::InvalidProofData("missing proof B x0 coordinate".to_string()) - })?, + json.bs.first().and_then(|row| row.first()).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B x0 coordinate".to_string()) + })?, )?; let proof_b_x_1 = deserialize_hex_string_to_be_bytes( - json.bs - .first() - .and_then(|row| row.get(1)) - .ok_or_else(|| { - ProverClientError::InvalidProofData("missing proof B x1 coordinate".to_string()) - })?, + json.bs.first().and_then(|row| row.get(1)).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B x1 coordinate".to_string()) + })?, )?; let proof_b_y_0 = deserialize_hex_string_to_be_bytes( - json.bs - .get(1) - .and_then(|row| row.first()) - .ok_or_else(|| { - ProverClientError::InvalidProofData("missing proof B y0 coordinate".to_string()) - })?, - )?; - let proof_b_y_1 = deserialize_hex_string_to_be_bytes( - json.bs - .get(1) - .and_then(|row| row.get(1)) - .ok_or_else(|| { - ProverClientError::InvalidProofData("missing proof B y1 coordinate".to_string()) - })?, + json.bs.get(1).and_then(|row| row.first()).ok_or_else(|| { + ProverClientError::InvalidProofData("missing proof B y0 coordinate".to_string()) + })?, )?; + let proof_b_y_1 = + deserialize_hex_string_to_be_bytes(json.bs.get(1).and_then(|row| row.get(1)).ok_or_else( + || ProverClientError::InvalidProofData("missing proof B y1 coordinate".to_string()), + )?)?; let proof_b: [u8; 128] = [proof_b_x_0, proof_b_x_1, proof_b_y_0, proof_b_y_1] .concat() .try_into() diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 8a02c363ff..7b8ceaa5f9 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -38,40 +38,39 @@ async fn prove_batch_address_append() { IndexedMerkleTree::::new(DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize, 0) .unwrap(); - let collect_non_inclusion_data = - |tree: &IndexedMerkleTree, values: &[BigUint]| { - let mut low_element_values = Vec::with_capacity(values.len()); - let mut low_element_indices = Vec::with_capacity(values.len()); - let mut low_element_next_indices = Vec::with_capacity(values.len()); - let mut low_element_next_values = Vec::with_capacity(values.len()); - let mut low_element_proofs: Vec< - [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize], - > = Vec::with_capacity(values.len()); - - for new_element_value in values { - let non_inclusion_proof = tree.get_non_inclusion_proof(new_element_value).unwrap(); - - low_element_values.push(non_inclusion_proof.leaf_lower_range_value); - low_element_indices.push(non_inclusion_proof.leaf_index); - low_element_next_indices.push(non_inclusion_proof.next_index); - low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push( - non_inclusion_proof - .merkle_proof - .as_slice() - .try_into() - .unwrap(), - ); - } - - ( - low_element_values, - low_element_indices, - low_element_next_indices, - low_element_next_values, - low_element_proofs, - ) - }; + let collect_non_inclusion_data = |tree: &IndexedMerkleTree, + values: &[BigUint]| { + let mut low_element_values = Vec::with_capacity(values.len()); + let mut low_element_indices = Vec::with_capacity(values.len()); + let mut low_element_next_indices = Vec::with_capacity(values.len()); + let mut low_element_next_values = Vec::with_capacity(values.len()); + let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = + Vec::with_capacity(values.len()); + + for new_element_value in values { + let non_inclusion_proof = tree.get_non_inclusion_proof(new_element_value).unwrap(); + + low_element_values.push(non_inclusion_proof.leaf_lower_range_value); + low_element_indices.push(non_inclusion_proof.leaf_index); + low_element_next_indices.push(non_inclusion_proof.next_index); + low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); + } + + ( + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + low_element_proofs, + ) + }; let initial_start_index = relayer_merkle_tree.merkle_tree.rightmost_index; let initial_root = relayer_merkle_tree.root(); diff --git a/sdk-libs/client/src/local_test_validator.rs b/sdk-libs/client/src/local_test_validator.rs index 4370f4911a..b27daa6a25 100644 --- a/sdk-libs/client/src/local_test_validator.rs +++ b/sdk-libs/client/src/local_test_validator.rs @@ -110,10 +110,7 @@ pub async fn spawn_validator(config: LightValidatorConfig) { .stderr(Stdio::inherit()) .spawn() .expect("Failed to start server process"); - let status = child - .wait() - .await - .expect("Failed to wait for CLI process"); + let status = child.wait().await.expect("Failed to wait for CLI process"); assert!(status.success(), "CLI exited with error: {}", status); } else { let _child = Command::new("sh") diff --git a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs index 154f4e2045..3e3fb4934d 100644 --- a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs +++ b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/read_only.rs @@ -127,7 +127,9 @@ async fn create_compressed_account( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let output_tree_index = rpc .get_random_state_tree_info() @@ -178,6 +180,7 @@ async fn read_sha256_light_system_cpi( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -231,6 +234,7 @@ async fn read_sha256_lowlevel( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -289,7 +293,9 @@ async fn create_compressed_account_poseidon( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let output_tree_index = rpc .get_random_state_tree_info() @@ -340,6 +346,7 @@ async fn read_poseidon_light_system_cpi( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -393,6 +400,7 @@ async fn read_poseidon_lowlevel( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs index e19d0742de..e5cde869bd 100644 --- a/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs +++ b/sdk-tests/sdk-anchor-test/programs/sdk-anchor-test/tests/test.rs @@ -171,7 +171,9 @@ async fn create_compressed_account( ) .await? .value; - let packed_accounts = rpc_result.pack_tree_infos(&mut remaining_accounts); + let packed_accounts = rpc_result + .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))?; let output_tree_index = rpc .get_random_state_tree_info() @@ -223,6 +225,7 @@ async fn update_compressed_account( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -277,6 +280,7 @@ async fn close_compressed_account( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -340,6 +344,7 @@ async fn reinit_closed_account( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); @@ -388,6 +393,7 @@ async fn close_compressed_account_permanent( let packed_tree_accounts = rpc_result .pack_tree_infos(&mut remaining_accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-native-test/tests/test.rs b/sdk-tests/sdk-native-test/tests/test.rs index 30d792487f..eb81e8cf47 100644 --- a/sdk-tests/sdk-native-test/tests/test.rs +++ b/sdk-tests/sdk-native-test/tests/test.rs @@ -103,7 +103,10 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_address_tree_info = rpc_result + .pack_tree_infos(&mut accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? + .address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { @@ -147,6 +150,7 @@ pub async fn update_pda( let packed_accounts = rpc_result .pack_tree_infos(&mut accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); diff --git a/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs b/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs index 0ae7f5c029..83e205bbf4 100644 --- a/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs +++ b/sdk-tests/sdk-pinocchio-v1-test/tests/test.rs @@ -101,7 +101,10 @@ pub async fn create_pda( .value; let output_merkle_tree_index = accounts.insert_or_get(*merkle_tree_pubkey); - let packed_address_tree_info = rpc_result.pack_tree_infos(&mut accounts).address_trees[0]; + let packed_address_tree_info = rpc_result + .pack_tree_infos(&mut accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? + .address_trees[0]; let (accounts, system_accounts_offset, tree_accounts_offset) = accounts.to_account_metas(); let instruction_data = CreatePdaInstructionData { proof: rpc_result.proof, @@ -145,6 +148,7 @@ pub async fn update_pda( let packed_accounts = rpc_result .pack_tree_infos(&mut accounts) + .map_err(|error| RpcError::CustomError(format!("Failed to pack tree infos: {error}")))? .state_trees .unwrap(); From 249149d08a6f07c0163bb9e522efd19c3f93e9de Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 16:35:36 +0000 Subject: [PATCH 11/14] feat: isolate and track forester worker concurrency --- forester/src/epoch_manager.rs | 815 ++++++++++++------ forester/src/metrics.rs | 6 +- forester/src/priority_fee.rs | 2 +- forester/tests/e2e_test.rs | 27 +- .../batched_state_async_indexer_test.rs | 2 +- forester/tests/legacy/test_utils.rs | 2 +- forester/tests/test_batch_append_spent.rs | 25 +- forester/tests/test_indexer_interface.rs | 3 +- forester/tests/test_utils.rs | 2 +- 9 files changed, 580 insertions(+), 304 deletions(-) diff --git a/forester/src/epoch_manager.rs b/forester/src/epoch_manager.rs index f52efa1b13..c1ee6544fc 100644 --- a/forester/src/epoch_manager.rs +++ b/forester/src/epoch_manager.rs @@ -1,4 +1,5 @@ use std::{ + any::Any, collections::HashMap, sync::{ atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, @@ -14,7 +15,7 @@ use forester_utils::{ forester_epoch::{get_epoch_phases, Epoch, ForesterSlot, TreeAccounts, TreeForesterSchedule}, rpc_pool::SolanaRpcPool, }; -use futures::future::join_all; +use futures::{future::join_all, stream::FuturesUnordered, FutureExt, StreamExt}; use light_client::{ indexer::{Indexer, MerkleProof, NewAddressProofWithContext}, rpc::{LightClient, LightClientConfig, RetryConfig, Rpc, RpcError}, @@ -39,7 +40,8 @@ use solana_sdk::{ transaction::TransactionError, }; use tokio::{ - sync::{mpsc, oneshot, Mutex}, + runtime::Handle, + sync::{mpsc, oneshot, Mutex, Semaphore}, task::JoinHandle, time::{sleep, Instant, MissedTickBehavior}, }; @@ -94,8 +96,6 @@ type StateBatchProcessorMap = type AddressBatchProcessorMap = Arc>>)>>; type ProcessorInitLockMap = Arc>>>; -type TreeProcessingTask = JoinHandle>; - /// Coordinates re-finalization across parallel `process_queue` tasks when new /// foresters register mid-epoch. Only one task performs the on-chain /// `finalize_registration` tx; others wait for it to complete. @@ -221,6 +221,46 @@ impl std::ops::AddAssign for ProcessingMetrics { } } +fn panic_payload_message(payload: &(dyn Any + Send)) -> String { + if let Some(message) = payload.downcast_ref::() { + message.clone() + } else if let Some(message) = payload.downcast_ref::<&'static str>() { + (*message).to_string() + } else { + "non-string panic payload".to_string() + } +} + +const NEW_TREE_WORKER_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); + +fn max_parallel_tree_workers(tree_count: usize) -> usize { + if tree_count == 0 { + return 0; + } + + let cpu_count = std::thread::available_parallelism() + .map(|parallelism| parallelism.get()) + .unwrap_or(4); + tree_count.min(std::cmp::max(1, cpu_count / 2)) +} + +struct NewTreeWorker { + tree: Pubkey, + epoch: u64, + cancel: Option>, + completion: oneshot::Receiver<()>, + thread_handle: std::thread::JoinHandle<()>, +} + +impl std::fmt::Debug for NewTreeWorker { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NewTreeWorker") + .field("tree", &self.tree) + .field("epoch", &self.epoch) + .finish_non_exhaustive() + } +} + #[derive(Copy, Clone, Debug)] pub struct WorkReport { pub epoch: u64, @@ -280,6 +320,7 @@ pub struct EpochManager { run_id: Arc, /// Per-epoch registration trackers to coordinate re-finalization when new foresters register mid-epoch registration_trackers: Arc>>, + new_tree_workers: Arc>>, } impl Clone for EpochManager { @@ -310,6 +351,7 @@ impl Clone for EpochManager { heartbeat: self.heartbeat.clone(), run_id: self.run_id.clone(), registration_trackers: self.registration_trackers.clone(), + new_tree_workers: self.new_tree_workers.clone(), } } } @@ -359,9 +401,129 @@ impl EpochManager { heartbeat, run_id: Arc::::from(run_id), registration_trackers: Arc::new(DashMap::new()), + new_tree_workers: Arc::new(Mutex::new(Vec::new())), }) } + fn join_new_tree_worker_with_run_id(run_id: Arc, worker: NewTreeWorker) { + if let Err(payload) = worker.thread_handle.join() { + error!( + event = "new_tree_worker_join_panicked", + run_id = %run_id, + tree = %worker.tree, + epoch = worker.epoch, + panic = %panic_payload_message(payload.as_ref()), + "New tree worker panicked while joining" + ); + } + } + + fn join_new_tree_worker(&self, worker: NewTreeWorker) { + Self::join_new_tree_worker_with_run_id(self.run_id.clone(), worker); + } + + fn detach_new_tree_worker_join(&self, worker: NewTreeWorker) { + let run_id = self.run_id.clone(); + let tree = worker.tree; + let epoch = worker.epoch; + std::thread::spawn(move || { + warn!( + event = "new_tree_worker_join_deferred", + run_id = %run_id, + tree = %tree, + epoch, + "Deferring timed-out new-tree worker join to background thread" + ); + Self::join_new_tree_worker_with_run_id(run_id, worker); + }); + } + + async fn reap_finished_new_tree_workers(&self) { + let finished_workers = { + let mut workers = self.new_tree_workers.lock().await; + let mut pending = Vec::with_capacity(workers.len()); + let mut finished = Vec::new(); + + for worker in workers.drain(..) { + if worker.thread_handle.is_finished() { + finished.push(worker); + } else { + pending.push(worker); + } + } + + *workers = pending; + finished + }; + + for mut worker in finished_workers { + let _ = worker.completion.try_recv(); + self.join_new_tree_worker(worker); + } + } + + async fn register_new_tree_worker(&self, worker: NewTreeWorker) { + self.reap_finished_new_tree_workers().await; + self.new_tree_workers.lock().await.push(worker); + } + + async fn shutdown_new_tree_workers(&self, timeout_duration: Duration) { + let mut workers = { + let mut guard = self.new_tree_workers.lock().await; + std::mem::take(&mut *guard) + }; + + if workers.is_empty() { + return; + } + + info!( + event = "new_tree_workers_shutdown_started", + run_id = %self.run_id, + worker_count = workers.len(), + timeout_secs = timeout_duration.as_secs_f64(), + "Shutting down tracked new-tree workers" + ); + + for worker in &mut workers { + if let Some(cancel) = worker.cancel.take() { + let _ = cancel.send(()); + } + } + + let deadline = Instant::now() + timeout_duration; + for mut worker in workers { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + warn!( + event = "new_tree_worker_shutdown_timed_out", + run_id = %self.run_id, + tree = %worker.tree, + epoch = worker.epoch, + "Timed out waiting for new-tree worker shutdown" + ); + self.detach_new_tree_worker_join(worker); + continue; + } + + match tokio::time::timeout(remaining, &mut worker.completion).await { + Ok(Ok(())) | Ok(Err(_)) => { + self.join_new_tree_worker(worker); + } + Err(_) => { + warn!( + event = "new_tree_worker_shutdown_timed_out", + run_id = %self.run_id, + tree = %worker.tree, + epoch = worker.epoch, + "Timed out waiting for new-tree worker shutdown" + ); + self.detach_new_tree_worker_join(worker); + } + } + } + } + pub async fn run(self: Arc) -> Result<()> { let (tx, mut rx) = mpsc::channel(100); let tx = Arc::new(tx); @@ -411,8 +573,40 @@ impl EpochManager { }, ); + let mut epoch_tasks = FuturesUnordered::new(); let result = loop { tokio::select! { + Some((epoch, result)) = epoch_tasks.next(), if !epoch_tasks.is_empty() => { + match result { + Ok(Ok(())) => { + debug!( + event = "epoch_processing_completed", + run_id = %self.run_id, + epoch, + "Epoch processed successfully" + ); + } + Ok(Err(e)) => { + error!( + event = "epoch_processing_failed", + run_id = %self.run_id, + epoch, + error = ?e, + "Error processing epoch" + ); + } + Err(payload) => { + let payload: Box = payload; + error!( + event = "epoch_processing_panicked", + run_id = %self.run_id, + epoch, + panic = %panic_payload_message(payload.as_ref()), + "Epoch processing panicked" + ); + } + } + } epoch_opt = rx.recv() => { match epoch_opt { Some(epoch) => { @@ -423,16 +617,11 @@ impl EpochManager { "Received epoch from monitor" ); let self_clone = Arc::clone(&self); - tokio::spawn(async move { - if let Err(e) = self_clone.process_epoch(epoch).await { - error!( - event = "epoch_processing_failed", - run_id = %self_clone.run_id, - epoch, - error = ?e, - "Error processing epoch" - ); - } + epoch_tasks.push(async move { + let result = std::panic::AssertUnwindSafe(self_clone.process_epoch(epoch)) + .catch_unwind() + .await; + (epoch, result) }); } None => { @@ -488,6 +677,8 @@ impl EpochManager { // Abort monitor_handle on exit monitor_handle.abort(); + self.shutdown_new_tree_workers(NEW_TREE_WORKER_SHUTDOWN_TIMEOUT) + .await; result } @@ -720,33 +911,87 @@ impl EpochManager { epoch = current_epoch, "Spawning task to process new tree in current epoch" ); - tokio::spawn(async move { + let (cancel_tx, mut cancel_rx) = oneshot::channel(); + let (completion_tx, completion_rx) = oneshot::channel(); + let thread_handle = std::thread::spawn(move || { let tree_pubkey = tree_schedule.tree_accounts.merkle_tree; - if let Err(e) = self_clone - .process_queue( - &epoch_info.epoch, - epoch_info.forester_epoch_pda.clone(), - tree_schedule, - tracker, - ) - .await - { - error!( - event = "new_tree_process_queue_failed", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - error = ?e, - "Error processing queue for new tree" - ); - } else { - info!( - event = "new_tree_process_queue_succeeded", - run_id = %self_clone.run_id, - tree = %tree_pubkey, - "Successfully processed new tree in current epoch" - ); + let run_id = self_clone.run_id.clone(); + let thread_run_id = run_id.clone(); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build()?; + runtime.block_on(async move { + tokio::select! { + result = self_clone + .clone() + .process_queue( + epoch_info.epoch.clone(), + epoch_info.forester_epoch_pda.clone(), + tree_schedule, + tracker, + ) => { + if let Err(e) = result { + error!( + event = "new_tree_process_queue_failed", + run_id = %thread_run_id, + tree = %tree_pubkey, + error = ?e, + "Error processing queue for new tree" + ); + } else { + info!( + event = "new_tree_process_queue_succeeded", + run_id = %thread_run_id, + tree = %tree_pubkey, + "Successfully processed new tree in current epoch" + ); + } + } + _ = &mut cancel_rx => { + info!( + event = "new_tree_process_queue_cancelled", + run_id = %thread_run_id, + tree = %tree_pubkey, + "Cancellation requested for new tree worker" + ); + } + } + Ok::<(), anyhow::Error>(()) + }) + })); + let _ = completion_tx.send(()); + match result { + Ok(Ok(())) => {} + Ok(Err(error)) => { + error!( + event = "new_tree_runtime_build_failed", + run_id = %run_id, + tree = %tree_pubkey, + error = ?error, + "Failed to build background runtime for new tree processing" + ); + } + Err(payload) => { + error!( + event = "new_tree_processing_task_panicked", + run_id = %run_id, + tree = %tree_pubkey, + panic = %panic_payload_message(payload.as_ref()), + "New tree processing thread panicked" + ); + } } }); + self.register_new_tree_worker(NewTreeWorker { + tree: new_tree.merkle_tree, + epoch: current_epoch, + cancel: Some(cancel_tx), + completion: completion_rx, + thread_handle, + }) + .await; } Ok(None) => { debug!( @@ -1615,16 +1860,19 @@ impl EpochManager { .cloned() .collect(); + let max_parallel_tree_workers = max_parallel_tree_workers(trees_to_process.len()); info!( event = "active_work_cycle_started", run_id = %self.run_id, current_slot, active_phase_end, tree_count = trees_to_process.len(), + parallel_tree_worker_limit = max_parallel_tree_workers, "Starting active work cycle" ); let self_arc = Arc::new(self.clone()); + let worker_slots = Arc::new(Semaphore::new(max_parallel_tree_workers)); let registration_tracker = self .registration_trackers .entry(epoch_info.epoch.epoch) @@ -1636,13 +1884,15 @@ impl EpochManager { .value() .clone(); - let mut handles: Vec = Vec::with_capacity(trees_to_process.len()); + let runtime_handle = Handle::current(); + let mut tasks = Vec::with_capacity(trees_to_process.len()); for tree in trees_to_process { + let tree_pubkey = tree.tree_accounts.merkle_tree; debug!( event = "tree_processing_task_spawned", run_id = %self.run_id, - tree = %tree.tree_accounts.merkle_tree, + tree = %tree_pubkey, tree_type = ?tree.tree_accounts.tree_type, "Spawning tree processing task" ); @@ -1652,41 +1902,68 @@ impl EpochManager { let epoch_clone = epoch_info.epoch.clone(); let forester_epoch_pda = epoch_info.forester_epoch_pda.clone(); let tracker = registration_tracker.clone(); - - let handle = tokio::spawn(async move { - self_clone - .process_queue(&epoch_clone, forester_epoch_pda, tree, tracker) - .await + let worker_slots = worker_slots.clone(); + let runtime_handle = runtime_handle.clone(); + tasks.push(async move { + let permit = match worker_slots.acquire_owned().await { + Ok(permit) => permit, + Err(_) => { + return Ok(( + tree_pubkey, + Err(anyhow!("tree worker semaphore was closed unexpectedly")), + )); + } + }; + tokio::task::spawn_blocking(move || { + let _permit = permit; + let result = runtime_handle.block_on(self_clone.process_queue( + epoch_clone, + forester_epoch_pda, + tree, + tracker, + )); + (tree_pubkey, result) + }) + .await }); - - handles.push(handle); } - debug!("Waiting for {} tree processing tasks", handles.len()); - let results = join_all(handles).await; + debug!("Waiting for {} tree processing tasks", tasks.len()); + let results = join_all(tasks).await; let mut success_count = 0usize; let mut error_count = 0usize; let mut panic_count = 0usize; for result in results { match result { - Ok(Ok(())) => success_count += 1, - Ok(Err(e)) => { + Ok((_, Ok(()))) => success_count += 1, + Ok((tree_pubkey, Err(e))) => { error_count += 1; error!( event = "tree_processing_task_failed", run_id = %self.run_id, + tree = %tree_pubkey, error = ?e, "Error processing queue" ); } - Err(e) => { + Err(join_error) => { panic_count += 1; - error!( - event = "tree_processing_task_panicked", - run_id = %self.run_id, - error = ?e, - "Tree processing task panicked" - ); + if join_error.is_panic() { + let payload = join_error.into_panic(); + error!( + event = "tree_processing_task_join_panicked", + run_id = %self.run_id, + panic = %panic_payload_message(payload.as_ref()), + "Tree processing task panicked before completion" + ); + } else { + error!( + event = "tree_processing_task_join_failed", + run_id = %self.run_id, + error = ?join_error, + "Tree processing task failed to join" + ); + } } } } @@ -1713,15 +1990,9 @@ impl EpochManager { Ok(current_slot) } - #[instrument( - level = "debug", - skip(self, epoch_info, forester_epoch_pda, tree_schedule, registration_tracker), - fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, - tree = %tree_schedule.tree_accounts.merkle_tree) - )] pub(crate) async fn process_queue( - &self, - epoch_info: &Epoch, + self: Arc, + epoch_info: Epoch, mut forester_epoch_pda: ForesterEpochPda, mut tree_schedule: TreeForesterSchedule, registration_tracker: Arc, @@ -1758,26 +2029,28 @@ impl EpochManager { if let Some((slot_idx, light_slot_details)) = next_slot_to_process { let result = match tree_type { TreeType::StateV1 | TreeType::AddressV1 | TreeType::Unknown => { - self.process_light_slot( - epoch_info, - &forester_epoch_pda, - &tree_schedule.tree_accounts, - &light_slot_details, - ) - .await + self.clone() + .process_light_slot( + epoch_info.clone(), + forester_epoch_pda.clone(), + tree_schedule.tree_accounts, + light_slot_details.clone(), + ) + .await } TreeType::StateV2 | TreeType::AddressV2 => { let consecutive_end = tree_schedule .get_consecutive_eligibility_end(slot_idx) .unwrap_or(light_slot_details.end_solana_slot); - self.process_light_slot_v2( - epoch_info, - &forester_epoch_pda, - &tree_schedule.tree_accounts, - &light_slot_details, - consecutive_end, - ) - .await + self.clone() + .process_light_slot_v2( + epoch_info.clone(), + forester_epoch_pda.clone(), + tree_schedule.tree_accounts, + light_slot_details.clone(), + consecutive_end, + ) + .await } }; @@ -1817,23 +2090,30 @@ impl EpochManager { // where cached_weight is correct but schedule was never recomputed. if force_refinalize || last_weight_check.elapsed() >= WEIGHT_CHECK_INTERVAL { last_weight_check = Instant::now(); - if let Err(e) = self + match self + .clone() .maybe_refinalize( - epoch_info, - &mut forester_epoch_pda, - &mut tree_schedule, - ®istration_tracker, + epoch_info.clone(), + forester_epoch_pda.clone(), + tree_schedule.clone(), + registration_tracker.clone(), force_refinalize, ) .await { - warn!( - event = "refinalize_check_failed", - run_id = %self.run_id, - forced = force_refinalize, - error = ?e, - "Failed to check/perform re-finalization" - ); + Ok((updated_pda, updated_schedule)) => { + forester_epoch_pda = updated_pda; + tree_schedule = updated_schedule; + } + Err(e) => { + warn!( + event = "refinalize_check_failed", + run_id = %self.run_id, + forced = force_refinalize, + error = ?e, + "Failed to check/perform re-finalization" + ); + } } } } else { @@ -1866,13 +2146,13 @@ impl EpochManager { /// When `force` is true (e.g. after a ForesterNotEligible error), skips /// the weight-change check and unconditionally refreshes the schedule. async fn maybe_refinalize( - &self, - epoch_info: &Epoch, - forester_epoch_pda: &mut ForesterEpochPda, - tree_schedule: &mut TreeForesterSchedule, - registration_tracker: &RegistrationTracker, + self: Arc, + epoch_info: Epoch, + forester_epoch_pda: ForesterEpochPda, + tree_schedule: TreeForesterSchedule, + registration_tracker: Arc, force: bool, - ) -> Result<()> { + ) -> Result<(ForesterEpochPda, TreeForesterSchedule)> { let mut rpc = self.rpc_pool.get_connection().await?; let epoch_pda_address = get_epoch_pda_address(epoch_info.epoch); let on_chain_epoch_pda: EpochPda = rpc @@ -1885,7 +2165,7 @@ impl EpochManager { let weight_changed = on_chain_weight != cached_weight; if !weight_changed && !force { - return Ok(()); + return Ok((forester_epoch_pda, tree_schedule)); } if weight_changed { @@ -1921,7 +2201,7 @@ impl EpochManager { "Skipping re-finalization because not enough active-phase time remains for confirmation" ); registration_tracker.complete_refinalize(cached_weight); - return Ok(()); + return Ok((forester_epoch_pda, tree_schedule)); }; let payer = self.config.payer_keypair.pubkey(); let signers = [&self.config.payer_keypair]; @@ -1999,33 +2279,24 @@ impl EpochManager { &refreshed_epoch_pda, )?; - *forester_epoch_pda = updated_pda; - *tree_schedule = new_schedule; - info!( event = "schedule_recomputed_after_refinalize", run_id = %self.run_id, epoch = epoch_info.epoch, - tree = %tree_schedule.tree_accounts.merkle_tree, - new_eligible_slots = tree_schedule.slots.iter().filter(|s| s.is_some()).count(), + tree = %new_schedule.tree_accounts.merkle_tree, + new_eligible_slots = new_schedule.slots.iter().filter(|s| s.is_some()).count(), "Recomputed schedule after re-finalization" ); - Ok(()) + Ok((updated_pda, new_schedule)) } - #[instrument( - level = "debug", - skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details), - fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, - tree = %tree_accounts.merkle_tree) - )] async fn process_light_slot( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + tree_accounts: TreeAccounts, + forester_slot_details: ForesterSlot, ) -> std::result::Result<(), ForesterError> { debug!( event = "light_slot_processing_started", @@ -2070,26 +2341,24 @@ impl EpochManager { break 'inner_processing_loop; } - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &tree_accounts.queue, - epoch_info.epoch, - epoch_info, - ) - .await? - { + if !self.check_forester_eligibility( + &epoch_pda, + current_light_slot, + &tree_accounts.queue, + epoch_info.epoch, + &epoch_info, + )? { break 'inner_processing_loop; } let processing_start_time = Instant::now(); let items_processed_this_iteration = match self + .clone() .dispatch_tree_processing( - epoch_info, - epoch_pda, + epoch_info.clone(), + epoch_pda.clone(), tree_accounts, - forester_slot_details, + forester_slot_details.clone(), forester_slot_details.end_solana_slot, estimated_slot, ) @@ -2158,17 +2427,12 @@ impl EpochManager { Ok(()) } - #[instrument( - level = "debug", - skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details, consecutive_eligibility_end), - fields(tree = %tree_accounts.merkle_tree) - )] async fn process_light_slot_v2( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + tree_accounts: TreeAccounts, + forester_slot_details: ForesterSlot, consecutive_eligibility_end: u64, ) -> std::result::Result<(), ForesterError> { debug!( @@ -2195,7 +2459,7 @@ impl EpochManager { // Try to send any cached proofs first let cached_send_start = Instant::now(); if let Some(items_sent) = self - .try_send_cached_proofs(epoch_info, tree_accounts, consecutive_eligibility_end) + .try_send_cached_proofs(&epoch_info, &tree_accounts, consecutive_eligibility_end) .await? { if items_sent > 0 { @@ -2242,27 +2506,25 @@ impl EpochManager { break 'inner_processing_loop; } - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &tree_accounts.merkle_tree, - epoch_info.epoch, - epoch_info, - ) - .await? - { + if !self.check_forester_eligibility( + &epoch_pda, + current_light_slot, + &tree_accounts.merkle_tree, + epoch_info.epoch, + &epoch_info, + )? { break 'inner_processing_loop; } // Process directly - the processor fetches queue data from the indexer let processing_start_time = Instant::now(); match self + .clone() .dispatch_tree_processing( - epoch_info, - epoch_pda, + epoch_info.clone(), + epoch_pda.clone(), tree_accounts, - forester_slot_details, + forester_slot_details.clone(), consecutive_eligibility_end, estimated_slot, ) @@ -2327,7 +2589,7 @@ impl EpochManager { Ok(()) } - async fn check_forester_eligibility( + fn check_forester_eligibility( &self, epoch_pda: &ForesterEpochPda, current_light_slot: u64, @@ -2389,16 +2651,17 @@ impl EpochManager { #[allow(clippy::too_many_arguments)] async fn dispatch_tree_processing( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + tree_accounts: TreeAccounts, + forester_slot_details: ForesterSlot, consecutive_eligibility_end: u64, current_solana_slot: u64, ) -> std::result::Result { match tree_accounts.tree_type { TreeType::Unknown => self + .clone() .dispatch_compression( epoch_info, epoch_pda, @@ -2408,18 +2671,24 @@ impl EpochManager { .await .map_err(ForesterError::from), TreeType::StateV1 | TreeType::AddressV1 => { - self.process_v1( - epoch_info, - epoch_pda, - tree_accounts, - forester_slot_details, - current_solana_slot, - ) - .await + self.clone() + .process_v1( + epoch_info, + epoch_pda, + tree_accounts, + forester_slot_details, + current_solana_slot, + ) + .await } TreeType::StateV2 | TreeType::AddressV2 => { let result = self - .process_v2(epoch_info, tree_accounts, consecutive_eligibility_end) + .clone() + .process_v2( + epoch_info.clone(), + tree_accounts, + consecutive_eligibility_end, + ) .await?; // Accumulate processing metrics for this epoch self.add_processing_metrics(epoch_info.epoch, result.metrics) @@ -2430,10 +2699,10 @@ impl EpochManager { } async fn dispatch_compression( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + forester_slot_details: ForesterSlot, consecutive_eligibility_end: u64, ) -> Result { let current_slot = self.slot_tracker.estimated_current_slot(); @@ -2455,16 +2724,13 @@ impl EpochManager { let current_light_slot = current_slot.saturating_sub(epoch_info.phases.active.start) / epoch_pda.protocol_config.slot_length; - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &Pubkey::default(), - epoch_info.epoch, - epoch_info, - ) - .await? - { + if !self.check_forester_eligibility( + &epoch_pda, + current_light_slot, + &Pubkey::default(), + epoch_info.epoch, + &epoch_info, + )? { debug!( "Skipping compression: forester not eligible for current light slot {}", current_light_slot @@ -2518,85 +2784,85 @@ impl EpochManager { // Create parallel compression futures use futures::stream::StreamExt; - // Collect chunks into owned vectors to avoid lifetime issues - let batches: Vec<(usize, Vec<_>)> = accounts - .chunks(config.batch_size) - .enumerate() - .map(|(idx, chunk)| (idx, chunk.to_vec())) - .collect(); - + let run_id = self.run_id.clone(); let slot_tracker = self.slot_tracker.clone(); // Shared cancellation flag - when set, all pending futures should skip processing let cancelled = Arc::new(AtomicBool::new(false)); - let compression_futures = batches.into_iter().map(|(batch_idx, batch)| { - let compressor = compressor.clone(); - let slot_tracker = slot_tracker.clone(); - let cancelled = cancelled.clone(); - async move { - // Check if already cancelled by another future - if cancelled.load(Ordering::Relaxed) { - debug!( - "Skipping compression batch {}/{}: cancelled", - batch_idx + 1, - num_batches - ); - return Err((batch_idx, batch.len(), Cancelled.into())); - } - - // Check forester is still eligible before processing this batch - let current_slot = slot_tracker.estimated_current_slot(); - if current_slot >= consecutive_eligibility_end { - // Signal cancellation to all other futures - cancelled.store(true, Ordering::Relaxed); - warn!( - event = "compression_ctoken_cancelled_not_eligible", - run_id = %self.run_id, - current_slot, - eligibility_end_slot = consecutive_eligibility_end, - "Cancelling compression because forester is no longer eligible" - ); - return Err(( - batch_idx, - batch.len(), - anyhow!("Forester no longer eligible"), - )); - } + let compression_futures = + accounts + .chunks(config.batch_size) + .enumerate() + .map(|(batch_idx, chunk)| { + let batch = chunk.to_vec(); + let compressor = compressor.clone(); + let run_id = run_id.clone(); + let slot_tracker = slot_tracker.clone(); + let cancelled = cancelled.clone(); + async move { + // Check if already cancelled by another future + if cancelled.load(Ordering::Relaxed) { + debug!( + "Skipping compression batch {}/{}: cancelled", + batch_idx + 1, + num_batches + ); + return Err((batch_idx, batch.len(), Cancelled.into())); + } - debug!( - "Processing compression batch {}/{} with {} accounts", - batch_idx + 1, - num_batches, - batch.len() - ); + // Check forester is still eligible before processing this batch + let current_slot = slot_tracker.estimated_current_slot(); + if current_slot >= consecutive_eligibility_end { + // Signal cancellation to all other futures + cancelled.store(true, Ordering::Relaxed); + warn!( + event = "compression_ctoken_cancelled_not_eligible", + run_id = %run_id, + current_slot, + eligibility_end_slot = consecutive_eligibility_end, + "Cancelling compression because forester is no longer eligible" + ); + return Err(( + batch_idx, + batch.len(), + anyhow!("Forester no longer eligible"), + )); + } - match compressor - .compress_batch(&batch, registered_forester_pda) - .await - { - Ok(sig) => { debug!( - "Compression batch {}/{} succeeded: {}", + "Processing compression batch {}/{} with {} accounts", batch_idx + 1, num_batches, - sig + batch.len() ); - Ok((batch_idx, batch.len(), sig)) - } - Err(e) => { - error!( - event = "compression_ctoken_batch_failed", - run_id = %self.run_id, - batch = batch_idx + 1, - total_batches = num_batches, - error = ?e, - "Compression batch failed" - ); - Err((batch_idx, batch.len(), e)) + + match compressor + .compress_batch(&batch, registered_forester_pda) + .await + { + Ok(sig) => { + debug!( + "Compression batch {}/{} succeeded: {}", + batch_idx + 1, + num_batches, + sig + ); + Ok((batch_idx, batch.len(), sig)) + } + Err(e) => { + error!( + event = "compression_ctoken_batch_failed", + run_id = %run_id, + batch = batch_idx + 1, + total_batches = num_batches, + error = ?e, + "Compression batch failed" + ); + Err((batch_idx, batch.len(), e)) + } + } } - } - } - }); + }); // Execute batches in parallel with concurrency limit let results = futures::stream::iter(compression_futures) @@ -2644,7 +2910,7 @@ impl EpochManager { // Process PDA compression if configured let pda_compressed = self - .dispatch_pda_compression(epoch_info, epoch_pda, consecutive_eligibility_end) + .dispatch_pda_compression(&epoch_info, &epoch_pda, consecutive_eligibility_end) .await .unwrap_or_else(|e| { error!( @@ -2658,7 +2924,7 @@ impl EpochManager { // Process Mint compression let mint_compressed = self - .dispatch_mint_compression(epoch_info, epoch_pda, consecutive_eligibility_end) + .dispatch_mint_compression(&epoch_info, &epoch_pda, consecutive_eligibility_end) .await .unwrap_or_else(|e| { error!( @@ -2943,16 +3209,13 @@ impl EpochManager { let current_light_slot = current_slot.saturating_sub(epoch_info.phases.active.start) / epoch_pda.protocol_config.slot_length; - if !self - .check_forester_eligibility( - epoch_pda, - current_light_slot, - &Pubkey::default(), - epoch_info.epoch, - epoch_info, - ) - .await? - { + if !self.check_forester_eligibility( + epoch_pda, + current_light_slot, + &Pubkey::default(), + epoch_info.epoch, + epoch_info, + )? { debug!( "Skipping {} compression: forester not eligible for current light slot {}", label, current_light_slot @@ -2964,11 +3227,11 @@ impl EpochManager { } async fn process_v1( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - tree_accounts: &TreeAccounts, - forester_slot_details: &ForesterSlot, + self: Arc, + epoch_info: Epoch, + epoch_pda: ForesterEpochPda, + tree_accounts: TreeAccounts, + forester_slot_details: ForesterSlot, current_solana_slot: u64, ) -> std::result::Result { let slots_remaining = forester_slot_details @@ -3018,7 +3281,7 @@ impl EpochManager { &self.config.derivation_pubkey, self.rpc_pool.clone(), &batched_tx_config, - *tree_accounts, + tree_accounts, transaction_builder, ) .await?; @@ -3033,7 +3296,7 @@ impl EpochManager { ); } - match self.rollover_if_needed(tree_accounts).await { + match self.rollover_if_needed(&tree_accounts).await { Ok(_) => Ok(num_sent), Err(e) => { error!( @@ -3283,15 +3546,15 @@ impl EpochManager { } async fn process_v2( - &self, - epoch_info: &Epoch, - tree_accounts: &TreeAccounts, + self: Arc, + epoch_info: Epoch, + tree_accounts: TreeAccounts, consecutive_eligibility_end: u64, ) -> std::result::Result { match tree_accounts.tree_type { TreeType::StateV2 => { let processor = self - .get_or_create_state_processor(epoch_info, tree_accounts) + .get_or_create_state_processor(&epoch_info, &tree_accounts) .await?; let cache = self @@ -3358,7 +3621,7 @@ impl EpochManager { } TreeType::AddressV2 => { let processor = self - .get_or_create_address_processor(epoch_info, tree_accounts) + .get_or_create_address_processor(&epoch_info, &tree_accounts) .await?; let cache = self diff --git a/forester/src/metrics.rs b/forester/src/metrics.rs index a7be12dd83..ece7d6a577 100644 --- a/forester/src/metrics.rs +++ b/forester/src/metrics.rs @@ -440,21 +440,19 @@ pub async fn metrics_handler() -> Result { if let Err(e) = encoder.encode(®ISTRY.gather(), &mut buffer) { error!("could not encode custom metrics: {}", e); }; - let mut res = String::from_utf8(buffer.clone()).unwrap_or_else(|e| { + let mut res = String::from_utf8(buffer).unwrap_or_else(|e| { error!("custom metrics could not be from_utf8'd: {}", e); String::new() }); - buffer.clear(); let mut buffer = Vec::new(); if let Err(e) = encoder.encode(&prometheus::gather(), &mut buffer) { error!("could not encode prometheus metrics: {}", e); }; - let res_prometheus = String::from_utf8(buffer.clone()).unwrap_or_else(|e| { + let res_prometheus = String::from_utf8(buffer).unwrap_or_else(|e| { error!("prometheus metrics could not be from_utf8'd: {}", e); String::new() }); - buffer.clear(); res.push_str(&res_prometheus); Ok(res) diff --git a/forester/src/priority_fee.rs b/forester/src/priority_fee.rs index c788712e1c..0b9ee4049c 100644 --- a/forester/src/priority_fee.rs +++ b/forester/src/priority_fee.rs @@ -208,7 +208,7 @@ pub async fn request_priority_fee_estimate( .map_err(|error| PriorityFeeEstimateError::ClientBuild(error.clone()))?; let response = http_client - .post(url.clone()) + .post(url.as_str()) .header("Content-Type", "application/json") .json(&rpc_request) .send() diff --git a/forester/tests/e2e_test.rs b/forester/tests/e2e_test.rs index 1727ed108b..735cf06925 100644 --- a/forester/tests/e2e_test.rs +++ b/forester/tests/e2e_test.rs @@ -277,7 +277,7 @@ async fn e2e_test() { validator_args: vec![], })) .await; - spawn_prover().await; + spawn_prover().await.unwrap(); } let mut rpc = setup_rpc_connection(&env.protocol.forester).await; @@ -799,15 +799,22 @@ async fn setup_forester_pipeline( let (shutdown_bootstrap_sender, shutdown_bootstrap_receiver) = oneshot::channel(); let (work_report_sender, work_report_receiver) = mpsc::channel(100); - let service_handle = tokio::spawn(run_pipeline::( - Arc::from(config.clone()), - None, - None, - shutdown_receiver, - Some(shutdown_compressible_receiver), - Some(shutdown_bootstrap_receiver), - work_report_sender, - )); + let config = Arc::new(config.clone()); + let service_handle = tokio::task::spawn_blocking(move || { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build()?; + runtime.block_on(run_pipeline::( + config, + None, + None, + shutdown_receiver, + Some(shutdown_compressible_receiver), + Some(shutdown_bootstrap_receiver), + work_report_sender, + )) + }); ( service_handle, diff --git a/forester/tests/legacy/batched_state_async_indexer_test.rs b/forester/tests/legacy/batched_state_async_indexer_test.rs index fe599a39a8..178927fea8 100644 --- a/forester/tests/legacy/batched_state_async_indexer_test.rs +++ b/forester/tests/legacy/batched_state_async_indexer_test.rs @@ -87,7 +87,7 @@ async fn test_state_indexer_async_batched() { validator_args: vec![], })) .await; - spawn_prover().await; + spawn_prover().await.unwrap(); let env = TestAccounts::get_local_test_validator_accounts(); let mut config = forester_config(); diff --git a/forester/tests/legacy/test_utils.rs b/forester/tests/legacy/test_utils.rs index d535665d71..76b146a160 100644 --- a/forester/tests/legacy/test_utils.rs +++ b/forester/tests/legacy/test_utils.rs @@ -26,7 +26,7 @@ pub async fn init(config: Option) { #[allow(dead_code)] pub async fn spawn_test_validator(config: Option) { let config = config.unwrap_or_default(); - spawn_validator(config).await; + spawn_validator(config).await.unwrap(); } #[allow(dead_code)] diff --git a/forester/tests/test_batch_append_spent.rs b/forester/tests/test_batch_append_spent.rs index e53c2b64eb..fe3bb8bef2 100644 --- a/forester/tests/test_batch_append_spent.rs +++ b/forester/tests/test_batch_append_spent.rs @@ -328,15 +328,22 @@ async fn run_forester(config: &ForesterConfig, duration: Duration) { tokio::sync::broadcast::channel(1); let (work_report_sender, _) = mpsc::channel(100); - let service_handle = tokio::spawn(run_pipeline::( - Arc::from(config.clone()), - None, - None, - shutdown_receiver, - Some(shutdown_compressible_receiver), - None, // shutdown_bootstrap - work_report_sender, - )); + let config = Arc::new(config.clone()); + let service_handle = tokio::task::spawn_blocking(move || { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build()?; + runtime.block_on(run_pipeline::( + config, + None, + None, + shutdown_receiver, + Some(shutdown_compressible_receiver), + None, // shutdown_bootstrap + work_report_sender, + )) + }); tokio::time::sleep(duration).await; diff --git a/forester/tests/test_indexer_interface.rs b/forester/tests/test_indexer_interface.rs index 6916cf0a3b..eaf56bf191 100644 --- a/forester/tests/test_indexer_interface.rs +++ b/forester/tests/test_indexer_interface.rs @@ -65,7 +65,8 @@ async fn test_indexer_interface_scenarios() { validator_args: vec![], use_surfpool: true, }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await diff --git a/forester/tests/test_utils.rs b/forester/tests/test_utils.rs index 4225503a19..3f50ad3695 100644 --- a/forester/tests/test_utils.rs +++ b/forester/tests/test_utils.rs @@ -36,7 +36,7 @@ pub async fn init(config: Option) { #[allow(dead_code)] pub async fn spawn_test_validator(config: Option) { let config = config.unwrap_or_default(); - spawn_validator(config).await; + spawn_validator(config).await.unwrap(); } #[allow(dead_code)] From 1fd7a066378156109096ae3d425a25c8e7c9eb00 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Tue, 17 Mar 2026 09:05:24 +0000 Subject: [PATCH 12/14] fix: harden forester concurrency review findings --- forester/src/epoch_manager.rs | 51 +++++++++++++--- forester/tests/e2e_test.rs | 2 +- .../batched_state_async_indexer_test.rs | 9 +-- forester/tests/legacy/test_utils.rs | 2 +- forester/tests/test_batch_append_spent.rs | 6 +- forester/tests/test_indexer_interface.rs | 3 +- forester/tests/test_utils.rs | 2 +- prover/client/src/helpers.rs | 15 +++-- .../proof_types/batch_address_append/json.rs | 7 +++ .../proof_types/batch_append/proof_inputs.rs | 13 +++++ .../src/proof_types/combined/v2/json.rs | 14 +++-- .../src/proof_types/non_inclusion/v2/json.rs | 21 +++++-- prover/client/tests/init_merkle_tree.rs | 2 +- .../client/src/interface/initialize_config.rs | 26 +++------ sdk-libs/client/src/local_test_validator.rs | 58 ++++++++++--------- .../program-test/src/indexer/test_indexer.rs | 3 +- 16 files changed, 154 insertions(+), 80 deletions(-) diff --git a/forester/src/epoch_manager.rs b/forester/src/epoch_manager.rs index c1ee6544fc..9818c280ac 100644 --- a/forester/src/epoch_manager.rs +++ b/forester/src/epoch_manager.rs @@ -321,6 +321,8 @@ pub struct EpochManager { /// Per-epoch registration trackers to coordinate re-finalization when new foresters register mid-epoch registration_trackers: Arc>>, new_tree_workers: Arc>>, + shutdown_requested: Arc, + shutdown_notify: Arc, } impl Clone for EpochManager { @@ -352,6 +354,8 @@ impl Clone for EpochManager { run_id: self.run_id.clone(), registration_trackers: self.registration_trackers.clone(), new_tree_workers: self.new_tree_workers.clone(), + shutdown_requested: self.shutdown_requested.clone(), + shutdown_notify: self.shutdown_notify.clone(), } } } @@ -402,9 +406,17 @@ impl EpochManager { run_id: Arc::::from(run_id), registration_trackers: Arc::new(DashMap::new()), new_tree_workers: Arc::new(Mutex::new(Vec::new())), + shutdown_requested: Arc::new(AtomicBool::new(false)), + shutdown_notify: Arc::new(tokio::sync::Notify::new()), }) } + fn request_shutdown(&self) { + if !self.shutdown_requested.swap(true, Ordering::AcqRel) { + self.shutdown_notify.notify_waiters(); + } + } + fn join_new_tree_worker_with_run_id(run_id: Arc, worker: NewTreeWorker) { if let Err(payload) = worker.thread_handle.join() { error!( @@ -575,7 +587,24 @@ impl EpochManager { let mut epoch_tasks = FuturesUnordered::new(); let result = loop { + if self.shutdown_requested.load(Ordering::Acquire) { + info!( + event = "epoch_manager_shutdown_requested", + run_id = %self.run_id, + "Stopping EpochManager after shutdown request" + ); + break Ok(()); + } + tokio::select! { + _ = self.shutdown_notify.notified() => { + info!( + event = "epoch_manager_shutdown_requested", + run_id = %self.run_id, + "Stopping EpochManager after shutdown request" + ); + break Ok(()); + } Some((epoch, result)) = epoch_tasks.next(), if !epoch_tasks.is_empty() => { match result { Ok(Ok(())) => { @@ -2179,6 +2208,13 @@ impl EpochManager { ); if registration_tracker.try_claim_refinalize() { + let completion_weight = Arc::new(AtomicU64::new(cached_weight)); + let guard_tracker = registration_tracker.clone(); + let guard_weight = completion_weight.clone(); + let _refinalize_guard = scopeguard::guard((), move |_| { + guard_tracker.complete_refinalize(guard_weight.load(Ordering::Acquire)); + }); + // This task sends the finalize_registration tx let ix = create_finalize_registration_instruction( &self.config.payer_keypair.pubkey(), @@ -2200,7 +2236,6 @@ impl EpochManager { active_phase_end_slot = epoch_info.phases.active.end, "Skipping re-finalization because not enough active-phase time remains for confirmation" ); - registration_tracker.complete_refinalize(cached_weight); return Ok((forester_epoch_pda, tree_schedule)); }; let payer = self.config.payer_keypair.pubkey(); @@ -2241,11 +2276,9 @@ impl EpochManager { new_weight = post_finalize_weight, "Re-finalized registration on-chain" ); - registration_tracker.complete_refinalize(post_finalize_weight); + completion_weight.store(post_finalize_weight, Ordering::Release); } Err(e) => { - // Release the claim so a future check can retry - registration_tracker.complete_refinalize(cached_weight); return Err(e.into()); } } @@ -4772,16 +4805,20 @@ pub async fn run_service( retry_count + 1 ); + let run_future = epoch_manager.clone().run(); + tokio::pin!(run_future); + let result = tokio::select! { - result = epoch_manager.run() => result, - _ = shutdown => { + result = &mut run_future => result, + _ = &mut shutdown => { info!( event = "shutdown_received", run_id = %run_id_for_logs, phase = "service_run", "Received shutdown signal. Stopping the service." ); - Ok(()) + epoch_manager.request_shutdown(); + run_future.await } }; diff --git a/forester/tests/e2e_test.rs b/forester/tests/e2e_test.rs index 735cf06925..c3579a7876 100644 --- a/forester/tests/e2e_test.rs +++ b/forester/tests/e2e_test.rs @@ -277,7 +277,7 @@ async fn e2e_test() { validator_args: vec![], })) .await; - spawn_prover().await.unwrap(); + spawn_prover().await; } let mut rpc = setup_rpc_connection(&env.protocol.forester).await; diff --git a/forester/tests/legacy/batched_state_async_indexer_test.rs b/forester/tests/legacy/batched_state_async_indexer_test.rs index 178927fea8..cffd8c92a2 100644 --- a/forester/tests/legacy/batched_state_async_indexer_test.rs +++ b/forester/tests/legacy/batched_state_async_indexer_test.rs @@ -23,9 +23,7 @@ use light_compressed_account::{ use light_compressed_token::process_transfer::{ transfer_sdk::create_transfer_instruction, TokenTransferOutputData, }; -use light_token::compat::TokenDataWithMerkleContext; use light_program_test::accounts::test_accounts::TestAccounts; -use light_prover_client::prover::spawn_prover; use light_registry::{ protocol_config::state::{ProtocolConfig, ProtocolConfigPda}, utils::get_protocol_config_pda_address, @@ -34,6 +32,7 @@ use light_test_utils::{ conversions::sdk_to_program_token_data, spl::create_mint_helper_with_keypair, system_program::create_invoke_instruction, }; +use light_token::compat::TokenDataWithMerkleContext; use rand::{prelude::SliceRandom, rngs::StdRng, Rng, SeedableRng}; use serial_test::serial; use solana_program::{native_token::LAMPORTS_PER_SOL, pubkey::Pubkey}; @@ -87,7 +86,6 @@ async fn test_state_indexer_async_batched() { validator_args: vec![], })) .await; - spawn_prover().await.unwrap(); let env = TestAccounts::get_local_test_validator_accounts(); let mut config = forester_config(); @@ -306,10 +304,7 @@ async fn wait_for_slot(rpc: &mut LightClient, target_slot: u64) { return; } Err(e) => { - println!( - "warp_to_slot unavailable ({}), falling back to polling", - e - ); + println!("warp_to_slot unavailable ({}), falling back to polling", e); } } while rpc.get_slot().await.unwrap() < target_slot { diff --git a/forester/tests/legacy/test_utils.rs b/forester/tests/legacy/test_utils.rs index 76b146a160..d535665d71 100644 --- a/forester/tests/legacy/test_utils.rs +++ b/forester/tests/legacy/test_utils.rs @@ -26,7 +26,7 @@ pub async fn init(config: Option) { #[allow(dead_code)] pub async fn spawn_test_validator(config: Option) { let config = config.unwrap_or_default(); - spawn_validator(config).await.unwrap(); + spawn_validator(config).await; } #[allow(dead_code)] diff --git a/forester/tests/test_batch_append_spent.rs b/forester/tests/test_batch_append_spent.rs index fe3bb8bef2..e2d93d39cf 100644 --- a/forester/tests/test_batch_append_spent.rs +++ b/forester/tests/test_batch_append_spent.rs @@ -349,7 +349,11 @@ async fn run_forester(config: &ForesterConfig, duration: Duration) { let _ = shutdown_sender.send(()); let _ = shutdown_compressible_sender.send(()); - let _ = timeout(Duration::from_secs(5), service_handle).await; + let join_result = timeout(Duration::from_secs(5), service_handle) + .await + .expect("forester service did not shut down within timeout"); + let service_result = join_result.expect("forester service task panicked"); + service_result.expect("run_pipeline::() failed"); } async fn get_onchain_root(rpc: &LightClient, tree_pubkey: Pubkey) -> (String, u64, u64) { diff --git a/forester/tests/test_indexer_interface.rs b/forester/tests/test_indexer_interface.rs index eaf56bf191..6916cf0a3b 100644 --- a/forester/tests/test_indexer_interface.rs +++ b/forester/tests/test_indexer_interface.rs @@ -65,8 +65,7 @@ async fn test_indexer_interface_scenarios() { validator_args: vec![], use_surfpool: true, }) - .await - .unwrap(); + .await; let mut rpc = LightClient::new(LightClientConfig::local()) .await diff --git a/forester/tests/test_utils.rs b/forester/tests/test_utils.rs index 3f50ad3695..4225503a19 100644 --- a/forester/tests/test_utils.rs +++ b/forester/tests/test_utils.rs @@ -36,7 +36,7 @@ pub async fn init(config: Option) { #[allow(dead_code)] pub async fn spawn_test_validator(config: Option) { let config = config.unwrap_or_default(); - spawn_validator(config).await.unwrap(); + spawn_validator(config).await; } #[allow(dead_code)] diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 9a20b8958e..f1d77b3bfa 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -2,7 +2,7 @@ use std::process::Command; use light_hasher::{Hasher, Poseidon}; use light_sparse_merkle_tree::changelog::ChangelogEntry; -use num_bigint::{BigInt, BigUint}; +use num_bigint::{BigInt, BigUint, Sign}; use num_traits::{Num, ToPrimitive}; use serde::Serialize; @@ -35,10 +35,17 @@ pub fn convert_endianness_128(bytes: &[u8]) -> Vec { .collect::>() } -pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], Box> { - let (_, bytes_be) = n.to_bytes_be(); +pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], ProverClientError> { + let (sign, bytes_be) = n.to_bytes_be(); + if sign == Sign::Minus { + return Err(ProverClientError::InvalidProofData( + "negative integers are not valid field elements".to_string(), + )); + } if bytes_be.len() > 32 { - Err("Number too large to fit in [u8; 32]")?; + return Err(ProverClientError::InvalidProofData( + "number too large to fit in [u8; 32]".to_string(), + )); } let mut array = [0; 32]; let bytes = &bytes_be[..bytes_be.len()]; diff --git a/prover/client/src/proof_types/batch_address_append/json.rs b/prover/client/src/proof_types/batch_address_append/json.rs index cd31a326e8..d27efd129a 100644 --- a/prover/client/src/proof_types/batch_address_append/json.rs +++ b/prover/client/src/proof_types/batch_address_append/json.rs @@ -19,6 +19,8 @@ pub struct BatchAddressAppendInputsJson { pub low_element_indices: Vec, #[serde(rename = "lowElementNextValues")] pub low_element_next_values: Vec, + #[serde(rename = "lowElementNextIndices")] + pub low_element_next_indices: Vec, #[serde(rename = "lowElementProofs")] pub low_element_proofs: Vec>, #[serde(rename = "newElementValues")] @@ -64,6 +66,11 @@ impl BatchAddressAppendInputsJson { .iter() .map(big_uint_to_string) .collect(), + low_element_next_indices: inputs + .low_element_next_indices + .iter() + .map(big_uint_to_string) + .collect(), low_element_proofs: inputs .low_element_proofs .iter() diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index 41a6dcfcd6..e72df989a9 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -128,6 +128,19 @@ pub fn get_batch_append_inputs( batch_size: u32, previous_changelogs: &[ChangelogEntry], ) -> Result<(BatchAppendsCircuitInputs, Vec>), ProverClientError> { + let batch_len = batch_size as usize; + for (name, len) in [ + ("old_leaves", old_leaves.len()), + ("leaves", leaves.len()), + ("merkle_proofs", merkle_proofs.len()), + ] { + if len != batch_len { + return Err(ProverClientError::GenericError(format!( + "invalid batch append inputs: {name} len {len} != batch_size {batch_len}" + ))); + } + } + let mut new_root = [0u8; 32]; let mut changelog: Vec> = Vec::new(); let mut circuit_merkle_proofs = Vec::with_capacity(batch_size as usize); diff --git a/prover/client/src/proof_types/combined/v2/json.rs b/prover/client/src/proof_types/combined/v2/json.rs index 322a5ee8ec..71de6b2ed3 100644 --- a/prover/client/src/proof_types/combined/v2/json.rs +++ b/prover/client/src/proof_types/combined/v2/json.rs @@ -2,6 +2,7 @@ use serde::Serialize; use crate::{ constants::{DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, DEFAULT_BATCH_STATE_TREE_HEIGHT}, + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{ circuit_type::CircuitType, @@ -29,21 +30,22 @@ pub struct CombinedJsonStruct { } impl CombinedJsonStruct { - pub fn from_combined_inputs(inputs: &CombinedProofInputs) -> Self { + pub fn from_combined_inputs(inputs: &CombinedProofInputs) -> Result { let inclusion_parameters = BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inputs.inclusion_parameters); - let non_inclusion_parameters = BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( - &inputs.non_inclusion_parameters, - ); + let non_inclusion_parameters = + BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( + &inputs.non_inclusion_parameters, + )?; - Self { + Ok(Self { circuit_type: CircuitType::Combined.to_string(), state_tree_height: DEFAULT_BATCH_STATE_TREE_HEIGHT, address_tree_height: DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, public_input_hash: big_int_to_string(&inputs.public_input_hash), inclusion: inclusion_parameters.inputs, non_inclusion: non_inclusion_parameters.inputs, - } + }) } #[allow(clippy::inherent_to_string)] diff --git a/prover/client/src/proof_types/non_inclusion/v2/json.rs b/prover/client/src/proof_types/non_inclusion/v2/json.rs index f6174e724d..9a556843af 100644 --- a/prover/client/src/proof_types/non_inclusion/v2/json.rs +++ b/prover/client/src/proof_types/non_inclusion/v2/json.rs @@ -2,6 +2,7 @@ use num_traits::ToPrimitive; use serde::Serialize; use crate::{ + errors::ProverClientError, helpers::{big_int_to_string, create_json_from_struct}, proof_types::{circuit_type::CircuitType, non_inclusion::v2::NonInclusionProofInputs}, }; @@ -24,7 +25,7 @@ pub struct NonInclusionJsonStruct { pub value: String, #[serde(rename(serialize = "pathIndex"))] - pub path_index: u32, + pub path_index: u64, #[serde(rename(serialize = "pathElements"))] pub path_elements: Vec, @@ -42,13 +43,23 @@ impl BatchNonInclusionJsonStruct { create_json_from_struct(&self) } - pub fn from_non_inclusion_proof_inputs(inputs: &NonInclusionProofInputs) -> Self { + pub fn from_non_inclusion_proof_inputs( + inputs: &NonInclusionProofInputs, + ) -> Result { let mut proof_inputs: Vec = Vec::new(); for input in inputs.inputs.iter() { let prof_input = NonInclusionJsonStruct { root: big_int_to_string(&input.root), value: big_int_to_string(&input.value), - path_index: input.index_hashed_indexed_element_leaf.to_u32().unwrap(), + path_index: input + .index_hashed_indexed_element_leaf + .to_u64() + .ok_or_else(|| { + ProverClientError::IntegerConversion(format!( + "failed to convert path index {} to u64", + input.index_hashed_indexed_element_leaf + )) + })?, path_elements: input .merkle_proof_hashed_indexed_element_leaf .iter() @@ -60,11 +71,11 @@ impl BatchNonInclusionJsonStruct { proof_inputs.push(prof_input); } - Self { + Ok(Self { circuit_type: CircuitType::NonInclusion.to_string(), address_tree_height: 40, public_input_hash: big_int_to_string(&inputs.public_input_hash), inputs: proof_inputs, - } + }) } } diff --git a/prover/client/tests/init_merkle_tree.rs b/prover/client/tests/init_merkle_tree.rs index 3bb5584cd3..1cba92ad6b 100644 --- a/prover/client/tests/init_merkle_tree.rs +++ b/prover/client/tests/init_merkle_tree.rs @@ -221,7 +221,7 @@ pub fn non_inclusion_new_with_public_inputs_v2( .collect(), path_index: merkle_inputs .index_hashed_indexed_element_leaf - .to_u32() + .to_u64() .unwrap(), leaf_lower_range_value: big_int_to_string(&merkle_inputs.leaf_lower_range_value), leaf_higher_range_value: big_int_to_string(&merkle_inputs.leaf_higher_range_value), diff --git a/sdk-libs/client/src/interface/initialize_config.rs b/sdk-libs/client/src/interface/initialize_config.rs index 9fbeacfe89..8a5d6b35a9 100644 --- a/sdk-libs/client/src/interface/initialize_config.rs +++ b/sdk-libs/client/src/interface/initialize_config.rs @@ -1,9 +1,10 @@ //! Helper for initializing config with sensible defaults. #[cfg(feature = "anchor")] -use anchor_lang::{AnchorDeserialize, AnchorSerialize}; +use anchor_lang::AnchorSerialize; #[cfg(not(feature = "anchor"))] -use borsh::{BorshDeserialize as AnchorDeserialize, BorshSerialize as AnchorSerialize}; +use borsh::BorshSerialize as AnchorSerialize; +use light_account::InitializeLightConfigParams; use solana_instruction::{AccountMeta, Instruction}; use solana_pubkey::Pubkey; @@ -16,16 +17,6 @@ pub const ADDRESS_TREE_V2: Pubkey = /// Default write top-up value (5000 lamports). pub const DEFAULT_INIT_WRITE_TOP_UP: u32 = 5_000; -/// Instruction data format matching anchor-generated `initialize_compression_config`. -#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug)] -pub struct InitializeCompressionConfigAnchorData { - pub write_top_up: u32, - pub rent_sponsor: Pubkey, - pub compression_authority: Pubkey, - pub rent_config: light_compressible::rent::RentConfig, - pub address_space: Vec, -} - /// Builder for `initialize_compression_config` instruction with sensible defaults. pub struct InitializeRentFreeConfig { program_id: Pubkey, @@ -109,12 +100,13 @@ impl InitializeRentFreeConfig { ), // system_program ]; - let instruction_data = InitializeCompressionConfigAnchorData { - write_top_up: self.write_top_up, - rent_sponsor: self.rent_sponsor, - compression_authority: self.compression_authority, + let instruction_data = InitializeLightConfigParams { + rent_sponsor: self.rent_sponsor.to_bytes(), + compression_authority: self.compression_authority.to_bytes(), rent_config: self.rent_config, - address_space: self.address_space, + write_top_up: self.write_top_up, + address_space: self.address_space.iter().map(|pubkey| pubkey.to_bytes()).collect(), + config_bump: self.config_bump, }; let serialized_data = instruction_data diff --git a/sdk-libs/client/src/local_test_validator.rs b/sdk-libs/client/src/local_test_validator.rs index b27daa6a25..bc6f0817da 100644 --- a/sdk-libs/client/src/local_test_validator.rs +++ b/sdk-libs/client/src/local_test_validator.rs @@ -58,53 +58,55 @@ impl Default for LightValidatorConfig { pub async fn spawn_validator(config: LightValidatorConfig) { if let Some(project_root) = get_project_root() { - let command = "cli/test_bin/run test-validator"; - let mut command = format!("{}/{}", project_root.trim(), command); + let project_root = project_root.trim_end_matches(['\n', '\r']); + let executable = format!("{}/cli/test_bin/run", project_root); + let mut args = vec!["test-validator".to_string()]; if !config.enable_indexer { - command.push_str(" --skip-indexer"); + args.push("--skip-indexer".to_string()); } if let Some(limit_ledger_size) = config.limit_ledger_size { - command.push_str(&format!(" --limit-ledger-size {}", limit_ledger_size)); + args.push("--limit-ledger-size".to_string()); + args.push(limit_ledger_size.to_string()); } for sbf_program in config.sbf_programs.iter() { - command.push_str(&format!( - " --sbf-program {} {}", - sbf_program.0, sbf_program.1 - )); + args.push("--sbf-program".to_string()); + args.push(sbf_program.0.clone()); + args.push(sbf_program.1.clone()); } for upgradeable_program in config.upgradeable_programs.iter() { - command.push_str(&format!( - " --upgradeable-program {} {} {}", - upgradeable_program.program_id, - upgradeable_program.program_path, - upgradeable_program.upgrade_authority - )); + args.push("--upgradeable-program".to_string()); + args.push(upgradeable_program.program_id.clone()); + args.push(upgradeable_program.program_path.clone()); + args.push(upgradeable_program.upgrade_authority.clone()); } if !config.enable_prover { - command.push_str(" --skip-prover"); + args.push("--skip-prover".to_string()); } if config.use_surfpool { - command.push_str(" --use-surfpool"); + args.push("--use-surfpool".to_string()); } for arg in config.validator_args.iter() { - command.push_str(&format!(" {}", arg)); + args.push(arg.clone()); } - println!("Starting validator with command: {}", command); + println!( + "Starting validator with command: {} {}", + executable, + args.join(" ") + ); if config.use_surfpool { // The CLI starts surfpool, prover, and photon, then exits once all // services are ready. Wait for it to finish so we know everything // is up before the test proceeds. - let mut child = Command::new("sh") - .arg("-c") - .arg(command) + let mut child = Command::new(&executable) + .args(&args) .stdin(Stdio::null()) .stdout(Stdio::inherit()) .stderr(Stdio::inherit()) @@ -113,17 +115,21 @@ pub async fn spawn_validator(config: LightValidatorConfig) { let status = child.wait().await.expect("Failed to wait for CLI process"); assert!(status.success(), "CLI exited with error: {}", status); } else { - let _child = Command::new("sh") - .arg("-c") - .arg(command) + let mut child = Command::new(&executable) + .args(&args) .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()) .spawn() .expect("Failed to start server process"); - // Intentionally detaching the spawned child; the caller only waits - // for the validator services to become available. tokio::time::sleep(tokio::time::Duration::from_secs(config.wait_time)).await; + if let Some(status) = child.try_wait().expect("Failed to poll validator process") { + assert!( + status.success(), + "Validator exited early with error: {}", + status + ); + } } } } diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 7618e045f3..f933912759 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -2372,7 +2372,8 @@ impl TestIndexer { Some( BatchNonInclusionJsonStruct::from_non_inclusion_proof_inputs( &non_inclusion_proof_inputs, - ), + ) + .map_err(|error| IndexerError::CustomError(error.to_string()))?, ), None, ) From 0ae954ad1119b4049e1daa98c4b3d0f174c693fe Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Tue, 17 Mar 2026 09:28:21 +0000 Subject: [PATCH 13/14] format --- sdk-libs/client/src/interface/initialize_config.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sdk-libs/client/src/interface/initialize_config.rs b/sdk-libs/client/src/interface/initialize_config.rs index 8a5d6b35a9..caa6f0036d 100644 --- a/sdk-libs/client/src/interface/initialize_config.rs +++ b/sdk-libs/client/src/interface/initialize_config.rs @@ -105,7 +105,11 @@ impl InitializeRentFreeConfig { compression_authority: self.compression_authority.to_bytes(), rent_config: self.rent_config, write_top_up: self.write_top_up, - address_space: self.address_space.iter().map(|pubkey| pubkey.to_bytes()).collect(), + address_space: self + .address_space + .iter() + .map(|pubkey| pubkey.to_bytes()) + .collect(), config_bump: self.config_bump, }; From bdb78501a3748dd2d0f64c9e475c9f0ca418a466 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 16:37:23 +0000 Subject: [PATCH 14/14] refactor: reduce clone pressure in compressible trackers --- .../src/compressible/ctoken/compressor.rs | 8 +++--- forester/src/compressible/ctoken/state.rs | 4 +-- forester/src/compressible/ctoken/types.rs | 4 ++- forester/src/compressible/mint/compressor.rs | 20 +++++++------- forester/src/compressible/mint/state.rs | 4 +-- forester/src/compressible/mint/types.rs | 4 ++- forester/src/compressible/pda/compressor.rs | 26 ++++++++++--------- forester/src/processor/v2/proof_cache.rs | 5 ++-- forester/tests/test_compressible_ctoken.rs | 7 +++-- forester/tests/test_compressible_mint.rs | 9 ++++--- 10 files changed, 52 insertions(+), 39 deletions(-) diff --git a/forester/src/compressible/ctoken/compressor.rs b/forester/src/compressible/ctoken/compressor.rs index 80e5ca4382..a74b3c335e 100644 --- a/forester/src/compressible/ctoken/compressor.rs +++ b/forester/src/compressible/ctoken/compressor.rs @@ -30,7 +30,7 @@ use crate::{ pub struct CTokenCompressor { rpc_pool: Arc>, tracker: Arc, - payer_keypair: Keypair, + payer_keypair: Arc, transaction_policy: TransactionPolicy, } @@ -39,7 +39,7 @@ impl Clone for CTokenCompressor { Self { rpc_pool: Arc::clone(&self.rpc_pool), tracker: Arc::clone(&self.tracker), - payer_keypair: self.payer_keypair.insecure_clone(), + payer_keypair: Arc::clone(&self.payer_keypair), transaction_policy: self.transaction_policy, } } @@ -55,7 +55,7 @@ impl CTokenCompressor { Self { rpc_pool, tracker, - payer_keypair, + payer_keypair: Arc::new(payer_keypair), transaction_policy, } } @@ -253,7 +253,7 @@ impl CTokenCompressor { send_and_confirm_with_tracking( &mut *rpc, &[ix], - &self.payer_keypair, + self.payer_keypair.as_ref(), self.transaction_policy, &*self.tracker, &pubkeys, diff --git a/forester/src/compressible/ctoken/state.rs b/forester/src/compressible/ctoken/state.rs index 1c71e785f3..098919f8e3 100644 --- a/forester/src/compressible/ctoken/state.rs +++ b/forester/src/compressible/ctoken/state.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::AtomicU64; +use std::sync::{atomic::AtomicU64, Arc}; use borsh::BorshDeserialize; use dashmap::{DashMap, DashSet}; @@ -126,7 +126,7 @@ impl CTokenAccountTracker { let state = CTokenAccountState { pubkey, - account: ctoken, + account: Arc::new(ctoken), lamports, compressible_slot, is_ata, diff --git a/forester/src/compressible/ctoken/types.rs b/forester/src/compressible/ctoken/types.rs index 5eb9de4132..63e16282ea 100644 --- a/forester/src/compressible/ctoken/types.rs +++ b/forester/src/compressible/ctoken/types.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use light_token_interface::state::Token; use solana_sdk::pubkey::Pubkey; @@ -6,7 +8,7 @@ use crate::compressible::traits::CompressibleState; #[derive(Clone, Debug)] pub struct CTokenAccountState { pub pubkey: Pubkey, - pub account: Token, + pub account: Arc, pub lamports: u64, /// Ready to compress when current_slot > compressible_slot pub compressible_slot: u64, diff --git a/forester/src/compressible/mint/compressor.rs b/forester/src/compressible/mint/compressor.rs index 889a28d5e1..8f2769257e 100644 --- a/forester/src/compressible/mint/compressor.rs +++ b/forester/src/compressible/mint/compressor.rs @@ -30,7 +30,7 @@ use crate::{ pub struct MintCompressor { rpc_pool: Arc>, tracker: Arc, - payer_keypair: Keypair, + payer_keypair: Arc, transaction_policy: TransactionPolicy, } @@ -39,7 +39,7 @@ impl Clone for MintCompressor { Self { rpc_pool: Arc::clone(&self.rpc_pool), tracker: Arc::clone(&self.tracker), - payer_keypair: self.payer_keypair.insecure_clone(), + payer_keypair: Arc::clone(&self.payer_keypair), transaction_policy: self.transaction_policy, } } @@ -55,7 +55,7 @@ impl MintCompressor { Self { rpc_pool, tracker, - payer_keypair, + payer_keypair: Arc::new(payer_keypair), transaction_policy, } } @@ -117,7 +117,7 @@ impl MintCompressor { send_and_confirm_with_tracking( &mut *rpc, &instructions, - &self.payer_keypair, + self.payer_keypair.as_ref(), self.transaction_policy, &*self.tracker, &pubkeys, @@ -160,7 +160,7 @@ impl MintCompressor { self.tracker.mark_pending(&all_pubkeys); // Create futures for each mint - let compression_futures = mint_states.iter().cloned().map(|mint_state| { + let compression_futures = mint_states.iter().map(|mint_state| { let compressor = self.clone(); let cancelled = cancelled.clone(); async move { @@ -168,18 +168,18 @@ impl MintCompressor { if cancelled.load(Ordering::Relaxed) { compressor.tracker.unmark_pending(&[mint_state.pubkey]); return CompressionOutcome::Failed { - state: mint_state, + state: mint_state.clone(), error: CompressionTaskError::Cancelled, }; } - match compressor.compress(&mint_state).await { + match compressor.compress(mint_state).await { Ok(sig) => CompressionOutcome::Compressed { signature: sig, - state: mint_state, + state: mint_state.clone(), }, Err(e) => CompressionOutcome::Failed { - state: mint_state, + state: mint_state.clone(), error: e.into(), }, } @@ -259,7 +259,7 @@ impl MintCompressor { let signature = send_and_confirm_with_tracking( &mut *rpc, &[ix], - &self.payer_keypair, + self.payer_keypair.as_ref(), self.transaction_policy, &*self.tracker, &tracked_pubkeys, diff --git a/forester/src/compressible/mint/state.rs b/forester/src/compressible/mint/state.rs index cca5a0b6fe..3bef0a3b9d 100644 --- a/forester/src/compressible/mint/state.rs +++ b/forester/src/compressible/mint/state.rs @@ -1,4 +1,4 @@ -use std::sync::atomic::AtomicU64; +use std::sync::{atomic::AtomicU64, Arc}; use borsh::BorshDeserialize; use dashmap::{DashMap, DashSet}; @@ -109,7 +109,7 @@ impl MintAccountTracker { pubkey, mint_seed, compressed_address, - mint, + mint: Arc::new(mint), lamports, compressible_slot, }; diff --git a/forester/src/compressible/mint/types.rs b/forester/src/compressible/mint/types.rs index 655e7532eb..6beb61e79c 100644 --- a/forester/src/compressible/mint/types.rs +++ b/forester/src/compressible/mint/types.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use light_token_interface::state::Mint; use solana_sdk::pubkey::Pubkey; @@ -8,7 +10,7 @@ pub struct MintAccountState { pub pubkey: Pubkey, pub mint_seed: Pubkey, pub compressed_address: [u8; 32], - pub mint: Mint, + pub mint: Arc, pub lamports: u64, /// Ready to compress when current_slot > compressible_slot pub compressible_slot: u64, diff --git a/forester/src/compressible/pda/compressor.rs b/forester/src/compressible/pda/compressor.rs index 941f0d3e53..8a04205e2d 100644 --- a/forester/src/compressible/pda/compressor.rs +++ b/forester/src/compressible/pda/compressor.rs @@ -55,7 +55,7 @@ pub struct CachedProgramConfig { pub struct PdaCompressor { rpc_pool: Arc>, tracker: Arc, - payer_keypair: Keypair, + payer_keypair: Arc, transaction_policy: TransactionPolicy, } @@ -64,7 +64,7 @@ impl Clone for PdaCompressor { Self { rpc_pool: Arc::clone(&self.rpc_pool), tracker: Arc::clone(&self.tracker), - payer_keypair: self.payer_keypair.insecure_clone(), + payer_keypair: Arc::clone(&self.payer_keypair), transaction_policy: self.transaction_policy, } } @@ -80,7 +80,7 @@ impl PdaCompressor { Self { rpc_pool, tracker, - payer_keypair, + payer_keypair: Arc::new(payer_keypair), transaction_policy, } } @@ -171,10 +171,12 @@ impl PdaCompressor { self.tracker.mark_pending(&all_pubkeys); // Create futures for each account - let compression_futures = account_states.iter().cloned().map(|account_state| { + let program_config = Arc::new(program_config.clone()); + let cached_config = Arc::new(cached_config.clone()); + let compression_futures = account_states.iter().map(|account_state| { let compressor = self.clone(); - let program_config = program_config.clone(); - let cached_config = cached_config.clone(); + let program_config = Arc::clone(&program_config); + let cached_config = Arc::clone(&cached_config); let cancelled = cancelled.clone(); async move { @@ -183,21 +185,21 @@ impl PdaCompressor { // Unmark since we won't process this account compressor.tracker.unmark_pending(&[account_state.pubkey]); return CompressionOutcome::Failed { - state: account_state, + state: account_state.clone(), error: CompressionTaskError::Cancelled, }; } match compressor - .compress(&account_state, &program_config, &cached_config) + .compress(account_state, &program_config, &cached_config) .await { Ok(sig) => CompressionOutcome::Compressed { signature: sig, - state: account_state, + state: account_state.clone(), }, Err(e) => CompressionOutcome::Failed { - state: account_state, + state: account_state.clone(), error: e.into(), }, } @@ -317,7 +319,7 @@ impl PdaCompressor { send_and_confirm_with_tracking( &mut *rpc, &[ix], - &self.payer_keypair, + self.payer_keypair.as_ref(), self.transaction_policy, &*self.tracker, &pubkeys, @@ -396,7 +398,7 @@ impl PdaCompressor { ); let payer_pubkey = self.payer_keypair.pubkey(); - let signers = [&self.payer_keypair]; + let signers = [self.payer_keypair.as_ref()]; let instructions = vec![ix]; let priority_fee_accounts = collect_priority_fee_accounts(payer_pubkey, &instructions); let signature = send_transaction_with_policy( diff --git a/forester/src/processor/v2/proof_cache.rs b/forester/src/processor/v2/proof_cache.rs index 123b4acf98..99812a853f 100644 --- a/forester/src/processor/v2/proof_cache.rs +++ b/forester/src/processor/v2/proof_cache.rs @@ -126,8 +126,9 @@ impl ProofCache { } } - self.proofs = self.warming_proofs.values().cloned().collect(); - self.warming_proofs.clear(); + self.proofs = std::mem::take(&mut self.warming_proofs) + .into_values() + .collect(); info!( "Cache warm-up complete for tree {}: {} proofs cached with root {:?}", diff --git a/forester/tests/test_compressible_ctoken.rs b/forester/tests/test_compressible_ctoken.rs index 569278286f..6183197dcc 100644 --- a/forester/tests/test_compressible_ctoken.rs +++ b/forester/tests/test_compressible_ctoken.rs @@ -320,7 +320,8 @@ async fn test_compressible_ctoken_compression() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await .expect("Failed to create LightClient"); @@ -403,6 +404,7 @@ async fn test_compressible_ctoken_compression() { account_type: ACCOUNT_TYPE_TOKEN_ACCOUNT, extensions: Some(vec![ExtensionStruct::Compressible(compressible_ext)]), } + .into() ); assert!(account_state.lamports > 0); let lamports = account_state.lamports; @@ -525,7 +527,8 @@ async fn test_compressible_ctoken_bootstrap() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await diff --git a/forester/tests/test_compressible_mint.rs b/forester/tests/test_compressible_mint.rs index 248db07251..30508acd05 100644 --- a/forester/tests/test_compressible_mint.rs +++ b/forester/tests/test_compressible_mint.rs @@ -145,7 +145,8 @@ async fn test_compressible_mint_bootstrap() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await @@ -287,7 +288,8 @@ async fn test_compressible_mint_compression() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await @@ -478,7 +480,8 @@ async fn test_compressible_mint_subscription() { use_surfpool: true, validator_args: vec![], }) - .await; + .await + .unwrap(); let mut rpc = LightClient::new(LightClientConfig::local()) .await