From ce73e21a449955bb85597045d149e13826066c98 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 16:32:06 +0000 Subject: [PATCH 1/9] feat: optimize address batch pipeline --- forester-utils/src/address_staging_tree.rs | 17 +- forester/src/processor/v2/helpers.rs | 63 ++++-- forester/src/processor/v2/strategy/address.rs | 79 +++---- program-tests/utils/src/e2e_test_env.rs | 103 ++++----- .../utils/src/mock_batched_forester.rs | 23 +- .../utils/src/test_batch_forester.rs | 71 +++--- prover/client/src/errors.rs | 3 + prover/client/src/helpers.rs | 10 +- .../batch_address_append/proof_inputs.rs | 84 ++++--- .../proof_types/batch_append/proof_inputs.rs | 7 +- .../proof_types/batch_update/proof_inputs.rs | 2 +- prover/client/tests/batch_address_append.rs | 23 +- sdk-libs/client/src/indexer/types/queue.rs | 208 ++++++++++++++++-- .../program-test/src/indexer/test_indexer.rs | 20 +- sparse-merkle-tree/src/indexed_changelog.rs | 6 +- sparse-merkle-tree/tests/indexed_changelog.rs | 14 +- 16 files changed, 472 insertions(+), 261 deletions(-) diff --git a/forester-utils/src/address_staging_tree.rs b/forester-utils/src/address_staging_tree.rs index 786ddb6ac0..a6b1aa89bd 100644 --- a/forester-utils/src/address_staging_tree.rs +++ b/forester-utils/src/address_staging_tree.rs @@ -121,7 +121,7 @@ impl AddressStagingTree { low_element_next_values: &[[u8; 32]], low_element_indices: &[u64], low_element_next_indices: &[u64], - low_element_proofs: &[Vec<[u8; 32]>], + low_element_proofs: &[[[u8; 32]; HEIGHT]], leaves_hashchain: [u8; 32], zkp_batch_size: usize, epoch: u64, @@ -145,15 +145,12 @@ impl AddressStagingTree { let inputs = get_batch_address_append_circuit_inputs::( next_index, old_root, - low_element_values.to_vec(), - low_element_next_values.to_vec(), - low_element_indices.iter().map(|v| *v as usize).collect(), - low_element_next_indices - .iter() - .map(|v| *v as usize) - .collect(), - low_element_proofs.to_vec(), - addresses.to_vec(), + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + addresses, &mut self.sparse_tree, leaves_hashchain, zkp_batch_size, diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index ed135cb6a4..a0f3e3bb5b 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -9,6 +9,7 @@ use light_client::{ indexer::{AddressQueueData, Indexer, QueueElementsV2Options, StateQueueData}, rpc::Rpc, }; +use light_hasher::hash_chain::create_hash_chain_from_slice; use crate::processor::v2::{common::clamp_to_u16, BatchContext}; @@ -22,6 +23,17 @@ pub(crate) fn lock_recover<'a, T>(mutex: &'a Mutex, name: &'static str) -> Mu } } +#[derive(Debug, Clone)] +pub struct AddressBatchSnapshot { + pub addresses: Vec<[u8; 32]>, + pub low_element_values: Vec<[u8; 32]>, + pub low_element_next_values: Vec<[u8; 32]>, + pub low_element_indices: Vec, + pub low_element_next_indices: Vec, + pub low_element_proofs: Vec<[[u8; 32]; HEIGHT]>, + pub leaves_hashchain: [u8; 32], +} + pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { let rpc = context.rpc_pool.get_connection().await?; let mut account = rpc @@ -474,20 +486,52 @@ impl StreamingAddressQueue { } } - pub fn get_batch_data(&self, start: usize, end: usize) -> Option { + pub fn get_batch_snapshot( + &self, + start: usize, + end: usize, + hashchain_idx: usize, + ) -> crate::Result>> { let available = self.wait_for_batch(end); if start >= available { - return None; + return Ok(None); } let actual_end = end.min(available); let data = lock_recover(&self.data, "streaming_address_queue.data"); - Some(BatchDataSlice { - addresses: data.addresses[start..actual_end].to_vec(), + + let addresses = data.addresses[start..actual_end].to_vec(); + if addresses.is_empty() { + return Err(anyhow!("Empty batch at start={}", start)); + } + + let leaves_hashchain = match data.leaves_hash_chains.get(hashchain_idx).copied() { + Some(hashchain) => hashchain, + None => { + tracing::debug!( + "Missing leaves_hash_chain for batch {} (available: {}), deriving from addresses", + hashchain_idx, + data.leaves_hash_chains.len() + ); + create_hash_chain_from_slice(&addresses).map_err(|error| { + anyhow!( + "Failed to derive leaves_hash_chain for batch {} from {} addresses: {}", + hashchain_idx, + addresses.len(), + error + ) + })? + } + }; + + Ok(Some(AddressBatchSnapshot { low_element_values: data.low_element_values[start..actual_end].to_vec(), low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), low_element_indices: data.low_element_indices[start..actual_end].to_vec(), low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - }) + low_element_proofs: data.reconstruct_proofs::(start..actual_end)?, + addresses, + leaves_hashchain, + })) } pub fn into_data(self) -> AddressQueueData { @@ -553,15 +597,6 @@ impl StreamingAddressQueue { } } -#[derive(Debug, Clone)] -pub struct BatchDataSlice { - pub addresses: Vec<[u8; 32]>, - pub low_element_values: Vec<[u8; 32]>, - pub low_element_next_values: Vec<[u8; 32]>, - pub low_element_indices: Vec, - pub low_element_next_indices: Vec, -} - pub async fn fetch_streaming_address_batches( context: &BatchContext, total_elements: u64, diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs index 06e94d5500..51ab05143a 100644 --- a/forester/src/processor/v2/strategy/address.rs +++ b/forester/src/processor/v2/strategy/address.rs @@ -14,11 +14,10 @@ use tracing::{debug, info, instrument}; use crate::processor::v2::{ batch_job_builder::BatchJobBuilder, - common::get_leaves_hashchain, errors::V2Error, helpers::{ fetch_address_zkp_batch_size, fetch_onchain_address_root, fetch_streaming_address_batches, - lock_recover, StreamingAddressQueue, + AddressBatchSnapshot, StreamingAddressQueue, }, proof_worker::ProofInput, root_guard::{reconcile_alignment, AlignmentDecision}, @@ -267,9 +266,23 @@ impl BatchJobBuilder for AddressQueueData { let batch_end = start + zkp_batch_size_usize; - let batch_data = self - .streaming_queue - .get_batch_data(start, batch_end) + let streaming_queue = &self.streaming_queue; + let staging_tree = &mut self.staging_tree; + let hashchain_idx = start / zkp_batch_size_usize; + let AddressBatchSnapshot { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + low_element_proofs, + leaves_hashchain, + } = streaming_queue + .get_batch_snapshot::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( + start, + batch_end, + hashchain_idx, + )? .ok_or_else(|| { anyhow!( "Batch data not available: start={}, end={}, available={}", @@ -278,31 +291,21 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.available_batches() * zkp_batch_size_usize ) })?; - - let addresses = &batch_data.addresses; let zkp_batch_size_actual = addresses.len(); - - if zkp_batch_size_actual == 0 { - return Err(anyhow!("Empty batch at start={}", start)); - } - - let low_element_values = &batch_data.low_element_values; - let low_element_next_values = &batch_data.low_element_next_values; - let low_element_indices = &batch_data.low_element_indices; - let low_element_next_indices = &batch_data.low_element_next_indices; - - let low_element_proofs: Vec> = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - (start..start + zkp_batch_size_actual) - .map(|i| data.reconstruct_proof(i, DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as u8)) - .collect::, _>>()? - }; - - let hashchain_idx = start / zkp_batch_size_usize; - let leaves_hashchain = { - let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); - get_leaves_hashchain(&data.leaves_hash_chains, hashchain_idx)? - }; + let result = staging_tree + .process_batch( + &addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + leaves_hashchain, + zkp_batch_size_actual, + epoch, + tree, + ) + .map_err(|err| map_address_staging_error(tree, err))?; let tree_batch = tree_next_index / zkp_batch_size_usize; let absolute_index = data_start + start; @@ -318,24 +321,6 @@ impl BatchJobBuilder for AddressQueueData { self.streaming_queue.is_complete() ); - let result = self.staging_tree.process_batch( - addresses, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - &low_element_proofs, - leaves_hashchain, - zkp_batch_size_actual, - epoch, - tree, - ); - - let result = match result { - Ok(r) => r, - Err(err) => return Err(map_address_staging_error(tree, err)), - }; - Ok(Some(( ProofInput::AddressAppend(result.circuit_inputs), result.new_root, diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index a16a7925a7..097ae1f9a9 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -73,7 +73,6 @@ use account_compression::{ use anchor_lang::{prelude::AccountMeta, AnchorSerialize, Discriminator}; use create_address_test_program::create_invoke_cpi_instruction; use forester_utils::{ - account_zero_copy::AccountZeroCopy, address_merkle_tree_config::{address_tree_ready_for_rollover, state_tree_ready_for_rollover}, forester_epoch::{Epoch, Forester, TreeAccounts}, utils::airdrop_lamports, @@ -194,6 +193,7 @@ use crate::{ }, test_batch_forester::{perform_batch_append, perform_batch_nullify}, test_forester::{empty_address_queue_test, nullify_compressed_accounts}, + AccountZeroCopy, }; pub struct User { @@ -748,70 +748,67 @@ where .with_address_queue(None, Some(batch.batch_size as u16)); let result = self .indexer - .get_queue_elements(merkle_tree_pubkey.to_bytes(), options, None) + .get_queue_elements( + merkle_tree_pubkey.to_bytes(), + options, + None, + ) .await .unwrap(); - let addresses = result - .value - .address_queue - .map(|aq| aq.addresses) - .unwrap_or_default(); + let address_queue = result.value.address_queue.unwrap(); + let low_element_proofs = address_queue + .reconstruct_all_proofs::<{ + DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize + }>() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.start_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = self - .indexer - .get_multiple_new_address_proofs( - merkle_tree_pubkey.to_bytes(), - addresses.clone(), - None, - ) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices - .push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices - .push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values - .push(non_inclusion_proof.low_address_next_value); - - low_element_proofs - .push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = self.indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); - let mut sparse_merkle_tree = SparseMerkleTree::::new(<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items).unwrap(), start_index); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + subtrees, + .. + } = address_queue; + let mut sparse_merkle_tree = SparseMerkleTree::< + Poseidon, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >::new( + subtrees.as_slice().try_into().unwrap(), + start_index, + ); - let mut changelog: Vec> = Vec::new(); - let mut indexed_changelog: Vec> = Vec::new(); + let mut changelog: Vec< + ChangelogEntry<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>, + > = Vec::new(); + let mut indexed_changelog: Vec< + IndexedChangelogEntry< + usize, + { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, + >, + > = Vec::new(); let inputs = get_batch_address_append_circuit_inputs::< { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, >( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, @@ -834,9 +831,13 @@ where if response_result.status().is_success() { let body = response_result.text().await.unwrap(); - let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); + let proof_json = deserialize_gnark_proof_json(&body) + .map_err(|error| RpcError::CustomError(error.to_string())) + .unwrap(); + let (proof_a, proof_b, proof_c) = + proof_from_json_struct(proof_json); + let (proof_a, proof_b, proof_c) = + compress_proof(&proof_a, &proof_b, &proof_c); let instruction_data = InstructionDataBatchNullifyInputs { new_root: circuit_inputs_new_root, compressed_proof: CompressedProof { diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 4458aa03b3..2a93f772ba 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -260,7 +260,7 @@ impl MockBatchedAddressForester { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; HEIGHT]> = Vec::new(); for new_element_value in &new_element_values { let non_inclusion_proof = self .merkle_tree @@ -270,7 +270,14 @@ impl MockBatchedAddressForester { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + let proof = non_inclusion_proof.merkle_proof.as_slice().try_into().map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid low element proof length: expected {}, got {}", + HEIGHT, + non_inclusion_proof.merkle_proof.len() + )) + })?; + low_element_proofs.push(proof); } let subtrees = self.merkle_tree.merkle_tree.get_subtrees(); let mut merkle_tree = match <[[u8; 32]; HEIGHT]>::try_from(subtrees) { @@ -287,12 +294,12 @@ impl MockBatchedAddressForester { let inputs = match get_batch_address_append_circuit_inputs::( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values.clone(), + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut merkle_tree, leaves_hashchain, zkp_batch_size as usize, diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index 8e6909704f..0952a7fd76 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -165,7 +165,9 @@ pub async fn create_append_batch_ix_data( bundle.merkle_tree.root() ); let proof_client = ProofClient::local(); - let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); + let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs) + .to_string() + ; match proof_client.generate_proof(inputs_json).await { Ok(compressed_proof) => ( @@ -319,13 +321,13 @@ pub async fn get_batched_nullify_ix_data( }) } -use forester_utils::{ - account_zero_copy::AccountZeroCopy, instructions::create_account::create_account_instruction, -}; +use forester_utils::instructions::create_account::create_account_instruction; use light_client::indexer::{Indexer, QueueElementsV2Options}; use light_program_test::indexer::state_tree::StateMerkleTreeBundle; use light_sparse_merkle_tree::SparseMerkleTree; +use crate::AccountZeroCopy; + pub async fn assert_registry_created_batched_state_merkle_tree( rpc: &mut R, payer_pubkey: Pubkey, @@ -663,50 +665,33 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof() + .unwrap(); // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_slice(addresses.as_slice()).unwrap(); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = merkle_tree.next_index as usize; + let start_index = address_queue.start_index as usize; assert!( start_index >= 1, "start index should be greater than 2 else tree is not inited" ); let current_root = *merkle_tree.root_history.last().unwrap(); - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); - let non_inclusion_proofs = indexer - .get_multiple_new_address_proofs(merkle_tree_pubkey.to_bytes(), addresses.clone(), None) - .await - .unwrap(); - for non_inclusion_proof in &non_inclusion_proofs.value.items { - low_element_values.push(non_inclusion_proof.low_address_value); - low_element_indices.push(non_inclusion_proof.low_address_index as usize); - low_element_next_indices.push(non_inclusion_proof.low_address_next_index as usize); - low_element_next_values.push(non_inclusion_proof.low_address_next_value); - - low_element_proofs.push(non_inclusion_proof.low_address_proof.to_vec()); - } - - let subtrees = indexer - .get_subtrees(merkle_tree_pubkey.to_bytes(), None) - .await - .unwrap(); + assert_eq!(address_queue.initial_root, current_root); + let light_client::indexer::AddressQueueData { + addresses, + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + subtrees, + .. + } = address_queue; let mut sparse_merkle_tree = SparseMerkleTree::< Poseidon, { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, - >::new( - <[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]>::try_from(subtrees.value.items) - .unwrap(), - start_index, - ); + >::new(subtrees.as_slice().try_into().unwrap(), start_index); let mut changelog: Vec> = Vec::new(); @@ -718,12 +703,12 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - addresses, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &addresses, &mut sparse_merkle_tree, leaves_hash_chain, batch.zkp_batch_size as usize, diff --git a/prover/client/src/errors.rs b/prover/client/src/errors.rs index 85c1bc8fbe..e095bf3579 100644 --- a/prover/client/src/errors.rs +++ b/prover/client/src/errors.rs @@ -37,6 +37,9 @@ pub enum ProverClientError { #[error("Invalid proof data: {0}")] InvalidProofData(String), + #[error("Integer conversion failed: {0}")] + IntegerConversion(String), + #[error("Hashchain mismatch: computed {computed:?} != expected {expected:?} (batch_size={batch_size}, next_index={next_index})")] HashchainMismatch { computed: [u8; 32], diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 6ea223e79f..98457479e2 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -6,6 +6,8 @@ use num_bigint::{BigInt, BigUint}; use num_traits::{Num, ToPrimitive}; use serde::Serialize; +use crate::errors::ProverClientError; + pub fn get_project_root() -> Option { let output = Command::new("git") .args(["rev-parse", "--show-toplevel"]) @@ -48,7 +50,7 @@ pub fn compute_root_from_merkle_proof( leaf: [u8; 32], path_elements: &[[u8; 32]; HEIGHT], path_index: u32, -) -> ([u8; 32], ChangelogEntry) { +) -> Result<([u8; 32], ChangelogEntry), ProverClientError> { let mut changelog_entry = ChangelogEntry::default_with_index(path_index as usize); let mut current_hash = leaf; @@ -56,14 +58,14 @@ pub fn compute_root_from_merkle_proof( for (level, path_element) in path_elements.iter().enumerate() { changelog_entry.path[level] = Some(current_hash); if current_index.is_multiple_of(2) { - current_hash = Poseidon::hashv(&[¤t_hash, path_element]).unwrap(); + current_hash = Poseidon::hashv(&[¤t_hash, path_element])?; } else { - current_hash = Poseidon::hashv(&[path_element, ¤t_hash]).unwrap(); + current_hash = Poseidon::hashv(&[path_element, ¤t_hash])?; } current_index /= 2; } - (current_hash, changelog_entry) + Ok((current_hash, changelog_entry)) } pub fn big_uint_to_string(big_uint: &BigUint) -> String { diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index f80e8d49e4..32408fdc02 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Debug}; use light_hasher::{ bigint::bigint_to_be_bytes_array, @@ -187,21 +187,28 @@ impl BatchAddressAppendInputs { pub fn get_batch_address_append_circuit_inputs( next_index: usize, current_root: [u8; 32], - low_element_values: Vec<[u8; 32]>, - low_element_next_values: Vec<[u8; 32]>, - low_element_indices: Vec, - low_element_next_indices: Vec, - low_element_proofs: Vec>, - new_element_values: Vec<[u8; 32]>, + low_element_values: &[[u8; 32]], + low_element_next_values: &[[u8; 32]], + low_element_indices: &[impl Copy + TryInto + Debug], + low_element_next_indices: &[impl Copy + TryInto + Debug], + low_element_proofs: &[[[u8; 32]; HEIGHT]], + new_element_values: &[[u8; 32]], sparse_merkle_tree: &mut SparseMerkleTree, leaves_hashchain: [u8; 32], zkp_batch_size: usize, changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - let new_element_values = new_element_values[0..zkp_batch_size].to_vec(); - - let computed_hashchain = create_hash_chain_from_slice(&new_element_values).map_err(|e| { + let new_element_values = &new_element_values[..zkp_batch_size]; + let mut new_root = [0u8; 32]; + let mut low_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); + let mut new_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_next_values = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_next_indices = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_values = Vec::with_capacity(new_element_values.len()); + let mut patched_low_element_indices = Vec::with_capacity(new_element_values.len()); + + let computed_hashchain = create_hash_chain_from_slice(new_element_values).map_err(|e| { ProverClientError::GenericError(format!("Failed to compute hashchain: {}", e)) })?; if computed_hashchain != leaves_hashchain { @@ -229,15 +236,6 @@ pub fn get_batch_address_append_circuit_inputs( next_index ); - let mut new_root = [0u8; 32]; - let mut low_element_circuit_merkle_proofs = vec![]; - let mut new_element_circuit_merkle_proofs = vec![]; - - let mut patched_low_element_next_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_next_indices: Vec = Vec::new(); - let mut patched_low_element_values: Vec<[u8; 32]> = Vec::new(); - let mut patched_low_element_indices: Vec = Vec::new(); - let mut patcher = ChangelogProofPatcher::new::(changelog); let is_first_batch = indexed_changelog.is_empty(); @@ -245,21 +243,33 @@ pub fn get_batch_address_append_circuit_inputs( for i in 0..new_element_values.len() { let mut changelog_index = 0; + let low_element_index = low_element_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element index {:?} does not fit into usize", + low_element_indices[i] + )) + })?; + let low_element_next_index = low_element_next_indices[i].try_into().map_err(|_| { + ProverClientError::IntegerConversion(format!( + "low element next index {:?} does not fit into usize", + low_element_next_indices[i] + )) + })?; let new_element_index = next_index + i; let mut low_element: IndexedElement = IndexedElement { - index: low_element_indices[i], + index: low_element_index, value: BigUint::from_bytes_be(&low_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; let mut new_element: IndexedElement = IndexedElement { index: new_element_index, value: BigUint::from_bytes_be(&new_element_values[i]), - next_index: low_element_next_indices[i], + next_index: low_element_next_index, }; - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof = low_element_proofs[i]; let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); patch_indexed_changelogs( 0, @@ -293,18 +303,10 @@ pub fn get_batch_address_append_circuit_inputs( next_value: bigint_to_be_bytes_array::<32>(&new_element.value)?, index: new_low_element.index, }; + let low_element_changelog_proof = low_element_proof; let intermediate_root = { - let mut low_element_proof_arr: [[u8; 32]; HEIGHT] = low_element_proof - .clone() - .try_into() - .map_err(|v: Vec<[u8; 32]>| { - ProverClientError::ProofPatchFailed(format!( - "low element proof length mismatch: expected {}, got {}", - HEIGHT, - v.len() - )) - })?; + let mut low_element_proof_arr = low_element_changelog_proof; patcher.update_proof::(low_element.index(), &mut low_element_proof_arr); let merkle_proof = low_element_proof_arr; @@ -321,7 +323,7 @@ pub fn get_batch_address_append_circuit_inputs( old_low_leaf_hash, &merkle_proof, low_element.index as u32, - ); + )?; if computed_root != expected_root_for_low { let low_value_bytes = bigint_to_be_bytes_array::<32>(&low_element.value) .map_err(|e| { @@ -362,7 +364,7 @@ pub fn get_batch_address_append_circuit_inputs( new_low_leaf_hash, &merkle_proof, new_low_element.index as u32, - ); + )?; patcher.push_changelog_entry::(changelog, changelog_entry); low_element_circuit_merkle_proofs.push( @@ -376,13 +378,7 @@ pub fn get_batch_address_append_circuit_inputs( }; let low_element_changelog_entry = IndexedChangelogEntry { element: new_low_element_raw, - proof: low_element_proof.as_slice()[..HEIGHT] - .try_into() - .map_err(|_| { - ProverClientError::ProofPatchFailed( - "low_element_proof slice conversion failed".to_string(), - ) - })?, + proof: low_element_changelog_proof, changelog_index: indexed_changelog.len(), //change_log_index, }; @@ -409,7 +405,7 @@ pub fn get_batch_address_append_circuit_inputs( new_element_leaf_hash, &merkle_proof_array, current_index as u32, - ); + )?; if i == 0 && changelog.len() == 1 { if sparse_next_idx_before != current_index { @@ -436,7 +432,7 @@ pub fn get_batch_address_append_circuit_inputs( zero_hash, &merkle_proof_array, current_index as u32, - ); + )?; if root_with_zero != intermediate_root { tracing::error!( "ELEMENT {} NEW_PROOF MISMATCH: proof + ZERO = {:?}[..4] but expected \ diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index ef0327ac1d..7dd578e599 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -187,8 +187,11 @@ pub fn get_batch_append_inputs( }; // Update the root based on the current proof and nullifier - let (updated_root, changelog_entry) = - compute_root_from_merkle_proof(final_leaf, &merkle_proof_array, start_index + i as u32); + let (updated_root, changelog_entry) = compute_root_from_merkle_proof( + final_leaf, + &merkle_proof_array, + start_index + i as u32, + )?; new_root = updated_root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 2136d01d10..7f8c08e0d1 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -175,7 +175,7 @@ pub fn get_batch_update_inputs( index_bytes[28..].copy_from_slice(&(*index).to_be_bytes()); let nullifier = Poseidon::hashv(&[leaf, &index_bytes, &tx_hashes[i]]).unwrap(); let (root, changelog_entry) = - compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index); + compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index)?; new_root = root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 22f58d5362..ac73c3809e 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -45,7 +45,8 @@ async fn prove_batch_address_append() { let mut low_element_indices = Vec::new(); let mut low_element_next_indices = Vec::new(); let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec> = Vec::new(); + let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = + Vec::new(); // Generate non-inclusion proofs for each element for new_element_value in &new_element_values { @@ -57,7 +58,13 @@ async fn prove_batch_address_append() { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push(non_inclusion_proof.merkle_proof.as_slice().to_vec()); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); } // Convert big integers to byte arrays @@ -87,12 +94,12 @@ async fn prove_batch_address_append() { get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( start_index, current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - new_element_values, + &low_element_values, + &low_element_next_values, + &low_element_indices, + &low_element_next_indices, + &low_element_proofs, + &new_element_values, &mut sparse_merkle_tree, hash_chain, zkp_batch_size, diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index 40e7cc0f6e..de97ca7739 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use super::super::IndexerError; #[derive(Debug, Clone, PartialEq, Default)] @@ -65,12 +67,10 @@ pub struct AddressQueueData { impl AddressQueueData { /// Reconstruct a merkle proof for a given low_element_index from the deduplicated nodes. - /// The tree_height is needed to know how many levels to traverse. - pub fn reconstruct_proof( + pub fn reconstruct_proof( &self, address_idx: usize, - tree_height: u8, - ) -> Result, IndexerError> { + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -81,10 +81,10 @@ impl AddressQueueData { ), } })?; - let mut proof = Vec::with_capacity(tree_height as usize); + let mut proof = [[0u8; 32]; HEIGHT]; let mut pos = leaf_index; - for level in 0..tree_height { + for (level, proof_element) in proof.iter_mut().enumerate() { let sibling_pos = if pos.is_multiple_of(2) { pos + 1 } else { @@ -114,30 +114,212 @@ impl AddressQueueData { self.node_hashes.len(), ), })?; - proof.push(*hash); + *proof_element = *hash; pos /= 2; } Ok(proof) } + /// Reconstruct a contiguous batch of proofs while reusing a single node lookup table. + pub fn reconstruct_proofs( + &self, + address_range: std::ops::Range, + ) -> Result, IndexerError> { + let node_lookup = self.build_node_lookup(); + let mut proofs = Vec::with_capacity(address_range.len()); + + for address_idx in address_range { + proofs.push(self.reconstruct_proof_with_lookup::(address_idx, &node_lookup)?); + } + + Ok(proofs) + } + /// Reconstruct all proofs for all addresses - pub fn reconstruct_all_proofs( + pub fn reconstruct_all_proofs( &self, - tree_height: u8, - ) -> Result>, IndexerError> { - (0..self.addresses.len()) - .map(|i| self.reconstruct_proof(i, tree_height)) + ) -> Result, IndexerError> { + self.reconstruct_proofs::(0..self.addresses.len()) + } + + fn build_node_lookup(&self) -> HashMap { + self.nodes + .iter() + .copied() + .enumerate() + .map(|(idx, node)| (node, idx)) .collect() } + fn reconstruct_proof_with_lookup( + &self, + address_idx: usize, + node_lookup: &HashMap, + ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "address_idx {} out of bounds for low_element_indices (len {})", + address_idx, + self.low_element_indices.len(), + ), + } + })?; + let mut proof = [[0u8; 32]; HEIGHT]; + let mut pos = leaf_index; + + for (level, proof_element) in proof.iter_mut().enumerate() { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let sibling_idx = Self::encode_node_index(level, sibling_pos); + let hash_idx = node_lookup.get(&sibling_idx).copied().ok_or_else(|| { + IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "Missing proof node at level {} position {} (encoded: {})", + level, sibling_pos, sibling_idx + ), + } + })?; + let hash = + self.node_hashes + .get(hash_idx) + .ok_or_else(|| IndexerError::MissingResult { + context: "reconstruct_proof".to_string(), + message: format!( + "node_hashes index {} out of bounds (len {})", + hash_idx, + self.node_hashes.len(), + ), + })?; + *proof_element = *hash; + pos /= 2; + } + + Ok(proof) + } + /// Encode node index: (level << 56) | position #[inline] - fn encode_node_index(level: u8, position: u64) -> u64 { + fn encode_node_index(level: usize, position: u64) -> u64 { ((level as u64) << 56) | position } } +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, hint::black_box, time::Instant}; + + use super::AddressQueueData; + + fn hash_from_node(node_index: u64) -> [u8; 32] { + let mut hash = [0u8; 32]; + hash[..8].copy_from_slice(&node_index.to_le_bytes()); + hash[8..16].copy_from_slice(&node_index.rotate_left(17).to_le_bytes()); + hash[16..24].copy_from_slice(&node_index.rotate_right(9).to_le_bytes()); + hash[24..32].copy_from_slice(&(node_index ^ 0xA5A5_A5A5_A5A5_A5A5).to_le_bytes()); + hash + } + + fn build_queue_data(num_addresses: usize) -> AddressQueueData { + let low_element_indices = (0..num_addresses) + .map(|i| (i as u64).saturating_mul(2)) + .collect::>(); + let mut nodes = BTreeMap::new(); + + for &leaf_index in &low_element_indices { + let mut pos = leaf_index; + for level in 0..HEIGHT { + let sibling_pos = if pos.is_multiple_of(2) { + pos + 1 + } else { + pos - 1 + }; + let node_index = ((level as u64) << 56) | sibling_pos; + nodes + .entry(node_index) + .or_insert_with(|| hash_from_node(node_index)); + pos /= 2; + } + } + + let (nodes, node_hashes): (Vec<_>, Vec<_>) = nodes.into_iter().unzip(); + + AddressQueueData { + addresses: vec![[0u8; 32]; num_addresses], + low_element_values: vec![[1u8; 32]; num_addresses], + low_element_next_values: vec![[2u8; 32]; num_addresses], + low_element_indices, + low_element_next_indices: (0..num_addresses).map(|i| (i as u64) + 1).collect(), + nodes, + node_hashes, + initial_root: [9u8; 32], + leaves_hash_chains: vec![[3u8; 32]; num_addresses.max(1)], + subtrees: vec![[4u8; 32]; HEIGHT], + start_index: 0, + root_seq: 0, + } + } + + #[test] + fn batched_reconstruction_matches_individual_reconstruction() { + let queue = build_queue_data::<40>(128); + + let expected = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::<40>(i).unwrap()) + .collect::>(); + let actual = queue + .reconstruct_proofs::<40>(0..queue.addresses.len()) + .unwrap(); + + assert_eq!(actual, expected); + } + + #[test] + #[ignore = "profiling helper"] + fn profile_reconstruct_proofs_batch() { + const HEIGHT: usize = 40; + const NUM_ADDRESSES: usize = 2_048; + const ITERS: usize = 25; + + let queue = build_queue_data::(NUM_ADDRESSES); + + let baseline_start = Instant::now(); + for _ in 0..ITERS { + let proofs = (0..queue.addresses.len()) + .map(|i| queue.reconstruct_proof::(i).unwrap()) + .collect::>(); + black_box(proofs); + } + let baseline = baseline_start.elapsed(); + + let batched_start = Instant::now(); + for _ in 0..ITERS { + black_box( + queue + .reconstruct_proofs::(0..queue.addresses.len()) + .unwrap(), + ); + } + let batched = batched_start.elapsed(); + + println!( + "queue reconstruction profile: addresses={}, height={}, iters={}, individual={:?}, batched={:?}, speedup={:.2}x", + NUM_ADDRESSES, + HEIGHT, + ITERS, + baseline, + batched, + baseline.as_secs_f64() / batched.as_secs_f64(), + ); + } +} + /// V2 Queue Elements Result with deduplicated node data #[derive(Debug, Clone, PartialEq, Default)] pub struct QueueElementsResult { diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 0b5b0583a3..cbaea17320 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -2170,18 +2170,20 @@ impl TestIndexer { let inclusion_proof_inputs = InclusionProofInputs::new(inclusion_proofs.as_slice()).unwrap(); ( - Some(BatchInclusionJsonStruct::from_inclusion_proof_inputs( - &inclusion_proof_inputs, - )), + Some( + BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inclusion_proof_inputs), + ), None, ) } else if height == STATE_MERKLE_TREE_HEIGHT as usize { let inclusion_proof_inputs = InclusionProofInputsLegacy(inclusion_proofs.as_slice()); ( None, - Some(BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( - &inclusion_proof_inputs, - )), + Some( + BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( + &inclusion_proof_inputs, + ), + ), ) } else { return Err(IndexerError::CustomError( @@ -2358,7 +2360,11 @@ impl TestIndexer { if let Some(payload) = payload { (indices, Vec::new(), payload.to_string()) } else { - (indices, Vec::new(), payload_legacy.unwrap().to_string()) + ( + indices, + Vec::new(), + payload_legacy.unwrap().to_string(), + ) } } (None, Some(addresses)) => { diff --git a/sparse-merkle-tree/src/indexed_changelog.rs b/sparse-merkle-tree/src/indexed_changelog.rs index bbd30e1ee6..7e6a26cff7 100644 --- a/sparse-merkle-tree/src/indexed_changelog.rs +++ b/sparse-merkle-tree/src/indexed_changelog.rs @@ -29,7 +29,7 @@ pub fn patch_indexed_changelogs( low_element: &mut IndexedElement, new_element: &mut IndexedElement, low_element_next_value: &mut BigUint, - low_leaf_proof: &mut Vec<[u8; 32]>, + low_leaf_proof: &mut [[u8; 32]; HEIGHT], ) -> Result<(), SparseMerkleTreeError> { // Tests are in program-tests/merkle-tree/tests/indexed_changelog.rs let next_indexed_changelog_indices: Vec = (*indexed_changelogs) @@ -69,7 +69,7 @@ pub fn patch_indexed_changelogs( // Patch the next value. *low_element_next_value = BigUint::from_bytes_be(&changelog_entry.element.next_value); // Patch the proof. - *low_leaf_proof = changelog_entry.proof.to_vec(); + *low_leaf_proof = changelog_entry.proof; } // If we found a new low element. @@ -82,7 +82,7 @@ pub fn patch_indexed_changelogs( next_index: new_low_element_changelog_entry.element.next_index, }; - *low_leaf_proof = new_low_element_changelog_entry.proof.to_vec(); + *low_leaf_proof = new_low_element_changelog_entry.proof; new_element.next_index = low_element.next_index; if new_low_element_changelog_index == indexed_changelogs.len() - 1 { return Ok(()); diff --git a/sparse-merkle-tree/tests/indexed_changelog.rs b/sparse-merkle-tree/tests/indexed_changelog.rs index 7d37142b46..59efda6fde 100644 --- a/sparse-merkle-tree/tests/indexed_changelog.rs +++ b/sparse-merkle-tree/tests/indexed_changelog.rs @@ -92,7 +92,8 @@ fn test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -114,7 +115,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -124,7 +125,7 @@ fn test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); println!("patched -------------------"); @@ -206,7 +207,8 @@ fn debug_test_indexed_changelog() { next_index: low_element_next_indices[i], }; println!("unpatched new_element: {:?}", new_element); - let mut low_element_proof = low_element_proofs[i].to_vec(); + let mut low_element_proof: [[u8; 32]; 8] = + low_element_proofs[i].as_slice().try_into().unwrap(); let mut low_element_next_value = BigUint::from_bytes_be(&low_element_next_values[i]); if i > 0 { @@ -228,7 +230,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), index: low_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); indexed_changelog.push(IndexedChangelogEntry { @@ -238,7 +240,7 @@ fn debug_test_indexed_changelog() { next_value: bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap(), index: new_element.index, }, - proof: low_element_proof.as_slice().to_vec().try_into().unwrap(), + proof: low_element_proof, changelog_index: indexed_changelog.len(), }); man_indexed_array.elements[low_element.index()] = low_element.clone(); From d0e469796f4dbd8e22703f5636d6dfa0c46b3e94 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 19:30:35 +0000 Subject: [PATCH 2/9] format --- .../utils/src/mock_batched_forester.rs | 18 ++++++++++------- .../utils/src/test_batch_forester.rs | 4 +--- .../program-test/src/indexer/test_indexer.rs | 20 +++++++------------ .../tests/integration_tests.rs | 2 +- 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 2a93f772ba..0101b235aa 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -270,13 +270,17 @@ impl MockBatchedAddressForester { low_element_indices.push(non_inclusion_proof.leaf_index); low_element_next_indices.push(non_inclusion_proof.next_index); low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - let proof = non_inclusion_proof.merkle_proof.as_slice().try_into().map_err(|_| { - ProverClientError::InvalidProofData(format!( - "invalid low element proof length: expected {}, got {}", - HEIGHT, - non_inclusion_proof.merkle_proof.len() - )) - })?; + let proof = non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .map_err(|_| { + ProverClientError::InvalidProofData(format!( + "invalid low element proof length: expected {}, got {}", + HEIGHT, + non_inclusion_proof.merkle_proof.len() + )) + })?; low_element_proofs.push(proof); } let subtrees = self.merkle_tree.merkle_tree.get_subtrees(); diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index 0952a7fd76..a28efa9ca3 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -165,9 +165,7 @@ pub async fn create_append_batch_ix_data( bundle.merkle_tree.root() ); let proof_client = ProofClient::local(); - let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs) - .to_string() - ; + let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); match proof_client.generate_proof(inputs_json).await { Ok(compressed_proof) => ( diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index cbaea17320..0b5b0583a3 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -2170,20 +2170,18 @@ impl TestIndexer { let inclusion_proof_inputs = InclusionProofInputs::new(inclusion_proofs.as_slice()).unwrap(); ( - Some( - BatchInclusionJsonStruct::from_inclusion_proof_inputs(&inclusion_proof_inputs), - ), + Some(BatchInclusionJsonStruct::from_inclusion_proof_inputs( + &inclusion_proof_inputs, + )), None, ) } else if height == STATE_MERKLE_TREE_HEIGHT as usize { let inclusion_proof_inputs = InclusionProofInputsLegacy(inclusion_proofs.as_slice()); ( None, - Some( - BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( - &inclusion_proof_inputs, - ), - ), + Some(BatchInclusionJsonStructLegacy::from_inclusion_proof_inputs( + &inclusion_proof_inputs, + )), ) } else { return Err(IndexerError::CustomError( @@ -2360,11 +2358,7 @@ impl TestIndexer { if let Some(payload) = payload { (indices, Vec::new(), payload.to_string()) } else { - ( - indices, - Vec::new(), - payload_legacy.unwrap().to_string(), - ) + (indices, Vec::new(), payload_legacy.unwrap().to_string()) } } (None, Some(addresses)) => { diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs index 9b40b900e5..2c3e82972a 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs @@ -3863,7 +3863,7 @@ async fn test_d9_edge_many_literals() { #[tokio::test] async fn test_d9_edge_mixed() { use csdk_anchor_full_derived_test::d9_seeds::{ - edge_cases::{AB, SEED_123, _UNDERSCORE_CONST}, + edge_cases::{_UNDERSCORE_CONST, AB, SEED_123}, D9EdgeMixedParams, }; From 4650c0ab33faa272a8809976fb8813585eefbf4f Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 19:42:28 +0000 Subject: [PATCH 3/9] feat: stabilize address batch pipeline --- Cargo.lock | 18 +++++++++--------- prover/client/src/constants.rs | 2 +- prover/client/src/prover.rs | 5 ++++- .../program-test/src/indexer/test_indexer.rs | 4 +++- .../tests/integration_tests.rs | 2 +- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9e293fd2b8..f65f0ac59e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4946,9 +4946,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-derive" @@ -11003,30 +11003,30 @@ dependencies = [ [[package]] name = "time" -version = "0.3.47" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "itoa", "num-conv", "powerfmt", - "serde_core", + "serde", "time-core", "time-macros", ] [[package]] name = "time-core" -version = "0.1.8" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-macros" -version = "0.2.27" +version = "0.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" dependencies = [ "num-conv", "time-core", diff --git a/prover/client/src/constants.rs b/prover/client/src/constants.rs index 18a5c05a45..151bf87918 100644 --- a/prover/client/src/constants.rs +++ b/prover/client/src/constants.rs @@ -1,4 +1,4 @@ -pub const SERVER_ADDRESS: &str = "http://localhost:3001"; +pub const SERVER_ADDRESS: &str = "http://127.0.0.1:3001"; pub const HEALTH_CHECK: &str = "/health"; pub const PROVE_PATH: &str = "/prove"; diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 3bf1bab785..56ae20d98a 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -51,7 +51,10 @@ pub async fn spawn_prover() { } pub async fn health_check(retries: usize, timeout: usize) -> bool { - let client = reqwest::Client::new(); + let client = match reqwest::Client::builder().no_proxy().build() { + Ok(client) => client, + Err(_) => return false, + }; let mut result = false; for _ in 0..retries { match client diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 0b5b0583a3..c51298b9cd 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -726,7 +726,9 @@ impl Indexer for TestIndexer { initial_root: address_tree_bundle.root(), leaves_hash_chains: Vec::new(), subtrees: address_tree_bundle.get_subtrees(), - start_index: start as u64, + // Consumers use start_index as the sparse tree's next insertion index, + // not the pagination offset used for queue slicing. + start_index: address_tree_bundle.right_most_index() as u64, root_seq: address_tree_bundle.sequence_number(), }) } else { diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs index 2c3e82972a..9b40b900e5 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs @@ -3863,7 +3863,7 @@ async fn test_d9_edge_many_literals() { #[tokio::test] async fn test_d9_edge_mixed() { use csdk_anchor_full_derived_test::d9_seeds::{ - edge_cases::{_UNDERSCORE_CONST, AB, SEED_123}, + edge_cases::{AB, SEED_123, _UNDERSCORE_CONST}, D9EdgeMixedParams, }; From ed44b772e86025fed20cf60372daf0f2e4fad9b3 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sat, 14 Mar 2026 16:32:39 +0000 Subject: [PATCH 4/9] feat: batch cold account loads in light client --- sdk-libs/client/src/indexer/photon_indexer.rs | 21 ++- sdk-libs/client/src/indexer/types/queue.rs | 29 +++- .../client/src/interface/load_accounts.rs | 125 +++++++++++++----- 3 files changed, 126 insertions(+), 49 deletions(-) diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index 26d16ae235..5698719c8f 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -1142,17 +1142,16 @@ impl Indexer for PhotonIndexer { .value .iter() .map(|x| { - let mut proof_vec = x.proof.clone(); - if proof_vec.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { + if x.proof.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { return Err(IndexerError::InvalidParameters(format!( "Merkle proof length ({}) is less than canopy depth ({})", - proof_vec.len(), + x.proof.len(), STATE_MERKLE_TREE_CANOPY_DEPTH, ))); } - proof_vec.truncate(proof_vec.len() - STATE_MERKLE_TREE_CANOPY_DEPTH); + let proof_len = x.proof.len() - STATE_MERKLE_TREE_CANOPY_DEPTH; - let proof = proof_vec + let proof = x.proof[..proof_len] .iter() .map(|s| Hash::from_base58(s)) .collect::, IndexerError>>() @@ -1703,15 +1702,13 @@ impl Indexer for PhotonIndexer { async fn get_subtrees( &self, - _merkle_tree_pubkey: [u8; 32], + merkle_tree_pubkey: [u8; 32], _config: Option, ) -> Result>, IndexerError> { - #[cfg(not(feature = "v2"))] - unimplemented!(); - #[cfg(feature = "v2")] - { - todo!(); - } + Err(IndexerError::NotImplemented(format!( + "PhotonIndexer::get_subtrees is not implemented for merkle tree {}", + solana_pubkey::Pubkey::new_from_array(merkle_tree_pubkey) + ))) } } diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index de97ca7739..3f52d72798 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -66,11 +66,14 @@ pub struct AddressQueueData { } impl AddressQueueData { + const ADDRESS_TREE_HEIGHT: usize = 40; + /// Reconstruct a merkle proof for a given low_element_index from the deduplicated nodes. pub fn reconstruct_proof( &self, address_idx: usize, ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -126,6 +129,7 @@ impl AddressQueueData { &self, address_range: std::ops::Range, ) -> Result, IndexerError> { + self.validate_proof_height::()?; let node_lookup = self.build_node_lookup(); let mut proofs = Vec::with_capacity(address_range.len()); @@ -140,16 +144,16 @@ impl AddressQueueData { pub fn reconstruct_all_proofs( &self, ) -> Result, IndexerError> { + self.validate_proof_height::()?; self.reconstruct_proofs::(0..self.addresses.len()) } fn build_node_lookup(&self) -> HashMap { - self.nodes - .iter() - .copied() - .enumerate() - .map(|(idx, node)| (node, idx)) - .collect() + let mut lookup = HashMap::with_capacity(self.nodes.len()); + for (idx, node) in self.nodes.iter().copied().enumerate() { + lookup.entry(node).or_insert(idx); + } + lookup } fn reconstruct_proof_with_lookup( @@ -157,6 +161,7 @@ impl AddressQueueData { address_idx: usize, node_lookup: &HashMap, ) -> Result<[[u8; 32]; HEIGHT], IndexerError> { + self.validate_proof_height::()?; let leaf_index = *self.low_element_indices.get(address_idx).ok_or_else(|| { IndexerError::MissingResult { context: "reconstruct_proof".to_string(), @@ -209,6 +214,18 @@ impl AddressQueueData { fn encode_node_index(level: usize, position: u64) -> u64 { ((level as u64) << 56) | position } + + fn validate_proof_height(&self) -> Result<(), IndexerError> { + if HEIGHT == Self::ADDRESS_TREE_HEIGHT { + return Ok(()); + } + + Err(IndexerError::InvalidParameters(format!( + "address queue proofs require HEIGHT={} but got HEIGHT={}", + Self::ADDRESS_TREE_HEIGHT, + HEIGHT + ))) + } } #[cfg(test)] diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index 061ad5074b..c70088dd40 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -53,6 +53,9 @@ pub enum LoadAccountsError { #[error("Cold PDA at index {index} (pubkey {pubkey}) missing data")] MissingPdaCompressed { index: usize, pubkey: Pubkey }, + #[error("Cold PDA (pubkey {pubkey}) missing data")] + MissingPdaCompressedData { pubkey: Pubkey }, + #[error("Cold ATA at index {index} (pubkey {pubkey}) missing data")] MissingAtaCompressed { index: usize, pubkey: Pubkey }, @@ -67,6 +70,7 @@ pub enum LoadAccountsError { } const MAX_ATAS_PER_IX: usize = 8; +const MAX_PDAS_PER_IX: usize = 8; /// Build load instructions for cold accounts. Returns empty vec if all hot. /// @@ -113,14 +117,18 @@ where }) .collect(); - let pda_hashes = collect_pda_hashes(&cold_pdas)?; + let pda_groups = group_pda_specs(&cold_pdas, MAX_PDAS_PER_IX); + let pda_hashes = pda_groups + .iter() + .map(|group| collect_pda_hashes(group)) + .collect::, _>>()?; let ata_hashes = collect_ata_hashes(&cold_atas)?; let mint_hashes = collect_mint_hashes(&cold_mints)?; let (pda_proofs, ata_proofs, mint_proofs) = futures::join!( - fetch_proofs(&pda_hashes, indexer), + fetch_proof_batches(&pda_hashes, indexer), fetch_proofs_batched(&ata_hashes, MAX_ATAS_PER_IX, indexer), - fetch_proofs(&mint_hashes, indexer), + fetch_individual_proofs(&mint_hashes, indexer), ); let pda_proofs = pda_proofs?; @@ -136,9 +144,9 @@ where // 2. DecompressAccountsIdempotent for all cold PDAs (including token PDAs). // Token PDAs are created on-chain via CPI inside DecompressVariant. - for (spec, proof) in cold_pdas.iter().zip(pda_proofs) { + for (group, proof) in pda_groups.into_iter().zip(pda_proofs) { out.push(build_pda_load( - &[spec], + &group, proof, fee_payer, compression_config, @@ -146,8 +154,7 @@ where } // 3. ATA loads (CreateAssociatedTokenAccount + Transfer2) - requires mint to exist - let ata_chunks: Vec<_> = cold_atas.chunks(MAX_ATAS_PER_IX).collect(); - for (chunk, proof) in ata_chunks.into_iter().zip(ata_proofs) { + for (chunk, proof) in cold_atas.chunks(MAX_ATAS_PER_IX).zip(ata_proofs) { out.extend(build_ata_load(chunk, proof, fee_payer)?); } @@ -195,23 +202,77 @@ fn collect_mint_hashes(ifaces: &[&AccountInterface]) -> Result, Lo .collect() } -async fn fetch_proofs( +/// Groups already-ordered PDA specs into contiguous runs of the same program id. +/// +/// This preserves input order rather than globally regrouping by program. Callers that +/// want maximal batching across interleaved program ids should sort before calling. +fn group_pda_specs<'a, V>( + specs: &[&'a PdaSpec], + max_per_group: usize, +) -> Vec>> { + assert!(max_per_group > 0, "max_per_group must be non-zero"); + if specs.is_empty() { + return Vec::new(); + } + + let mut groups = Vec::new(); + let mut current = Vec::with_capacity(max_per_group); + let mut current_program: Option = None; + + for spec in specs { + let program_id = spec.program_id(); + let should_split = current_program + .map(|existing| existing != program_id || current.len() >= max_per_group) + .unwrap_or(false); + + if should_split { + groups.push(current); + current = Vec::with_capacity(max_per_group); + } + + current_program = Some(program_id); + current.push(*spec); + } + + if !current.is_empty() { + groups.push(current); + } + + groups +} + +async fn fetch_individual_proofs( hashes: &[[u8; 32]], indexer: &I, ) -> Result, IndexerError> { if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len()); - for hash in hashes { - proofs.push( - indexer - .get_validity_proof(vec![*hash], vec![], None) - .await? - .value, - ); + + futures::future::try_join_all(hashes.iter().map(|hash| async move { + indexer + .get_validity_proof(vec![*hash], vec![], None) + .await + .map(|response| response.value) + })) + .await +} + +async fn fetch_proof_batches( + hash_batches: &[Vec<[u8; 32]>], + indexer: &I, +) -> Result, IndexerError> { + if hash_batches.is_empty() { + return Ok(vec![]); } - Ok(proofs) + + futures::future::try_join_all(hash_batches.iter().map(|hashes| async move { + indexer + .get_validity_proof(hashes.clone(), vec![], None) + .await + .map(|response| response.value) + })) + .await } async fn fetch_proofs_batched( @@ -222,16 +283,13 @@ async fn fetch_proofs_batched( if hashes.is_empty() { return Ok(vec![]); } - let mut proofs = Vec::with_capacity(hashes.len().div_ceil(batch_size)); - for chunk in hashes.chunks(batch_size) { - proofs.push( - indexer - .get_validity_proof(chunk.to_vec(), vec![], None) - .await? - .value, - ); - } - Ok(proofs) + + let hash_batches = hashes + .chunks(batch_size) + .map(|chunk| chunk.to_vec()) + .collect::>(); + + fetch_proof_batches(&hash_batches, indexer).await } fn build_pda_load( @@ -262,11 +320,16 @@ where let hot_addresses: Vec = specs.iter().map(|s| s.address()).collect(); let cold_accounts: Vec<(CompressedAccount, V)> = specs .iter() - .map(|s| { - let compressed = s.compressed().expect("cold spec must have data").clone(); - (compressed, s.variant.clone()) + .map(|s| -> Result<_, LoadAccountsError> { + let compressed = + s.compressed() + .cloned() + .ok_or(LoadAccountsError::MissingPdaCompressedData { + pubkey: s.address(), + })?; + Ok((compressed, s.variant.clone())) }) - .collect(); + .collect::, _>>()?; let program_id = specs.first().map(|s| s.program_id()).unwrap_or_default(); From f9a788c5ca751b1cf7679704b7e2ba6460b180e5 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Sun, 15 Mar 2026 14:42:54 +0000 Subject: [PATCH 5/9] fix: harden load batching and mixed decompression --- forester/src/processor/v2/helpers.rs | 29 ++- forester/src/processor/v2/processor.rs | 6 +- forester/src/processor/v2/proof_worker.rs | 18 +- forester/src/processor/v2/strategy/address.rs | 2 +- program-tests/utils/src/e2e_test_env.rs | 3 +- .../utils/src/mock_batched_forester.rs | 9 +- .../utils/src/test_batch_forester.rs | 8 +- prover/client/src/helpers.rs | 4 +- prover/client/src/proof_client.rs | 31 ++- .../batch_address_append/proof_inputs.rs | 41 ++-- .../proof_types/batch_append/proof_inputs.rs | 2 +- .../proof_types/batch_update/proof_inputs.rs | 2 +- prover/client/src/prover.rs | 200 +++++++++++++++--- prover/client/tests/batch_address_append.rs | 131 ++++++++---- sdk-libs/client/src/indexer/photon_indexer.rs | 19 +- sdk-libs/client/src/indexer/types/queue.rs | 21 ++ .../client/src/interface/load_accounts.rs | 50 +++-- .../program-test/src/indexer/test_indexer.rs | 5 +- .../interface/program/decompression/pda.rs | 11 +- .../program/decompression/processor.rs | 15 +- .../interface/program/decompression/token.rs | 5 +- 21 files changed, 452 insertions(+), 160 deletions(-) diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index a0f3e3bb5b..a9fa2e290d 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -493,12 +493,29 @@ impl StreamingAddressQueue { hashchain_idx: usize, ) -> crate::Result>> { let available = self.wait_for_batch(end); - if start >= available { + if available < end || start >= end { return Ok(None); } - let actual_end = end.min(available); + let actual_end = end; let data = lock_recover(&self.data, "streaming_address_queue.data"); + for (name, len) in [ + ("addresses", data.addresses.len()), + ("low_element_values", data.low_element_values.len()), + ("low_element_next_values", data.low_element_next_values.len()), + ("low_element_indices", data.low_element_indices.len()), + ("low_element_next_indices", data.low_element_next_indices.len()), + ] { + if len < actual_end { + return Err(anyhow!( + "incomplete batch data: {} len {} < required end {}", + name, + len, + actual_end + )); + } + } + let addresses = data.addresses[start..actual_end].to_vec(); if addresses.is_empty() { return Err(anyhow!("Empty batch at start={}", start)); @@ -528,7 +545,9 @@ impl StreamingAddressQueue { low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), low_element_indices: data.low_element_indices[start..actual_end].to_vec(), low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - low_element_proofs: data.reconstruct_proofs::(start..actual_end)?, + low_element_proofs: data.reconstruct_proofs::(start..actual_end).map_err( + |error| anyhow!("incomplete batch data: failed to reconstruct proofs: {error}"), + )?, addresses, leaves_hashchain, })) @@ -566,6 +585,10 @@ impl StreamingAddressQueue { lock_recover(&self.data, "streaming_address_queue.data").start_index } + pub fn tree_next_insertion_index(&self) -> u64 { + lock_recover(&self.data, "streaming_address_queue.data").tree_next_insertion_index + } + pub fn subtrees(&self) -> Vec<[u8; 32]> { lock_recover(&self.data, "streaming_address_queue.data") .subtrees diff --git a/forester/src/processor/v2/processor.rs b/forester/src/processor/v2/processor.rs index 3de6dea860..372a800e0e 100644 --- a/forester/src/processor/v2/processor.rs +++ b/forester/src/processor/v2/processor.rs @@ -132,7 +132,7 @@ where } if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } @@ -532,7 +532,7 @@ where ((queue_size / self.zkp_batch_size) as usize).min(self.context.max_batches_per_tree); if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } @@ -561,7 +561,7 @@ where let max_batches = max_batches.min(self.context.max_batches_per_tree); if self.worker_pool.is_none() { - let job_tx = spawn_proof_workers(&self.context.prover_config); + let job_tx = spawn_proof_workers(&self.context.prover_config)?; self.worker_pool = Some(WorkerPool { job_tx }); } diff --git a/forester/src/processor/v2/proof_worker.rs b/forester/src/processor/v2/proof_worker.rs index b7afeacf0b..ded9fcedc8 100644 --- a/forester/src/processor/v2/proof_worker.rs +++ b/forester/src/processor/v2/proof_worker.rs @@ -132,27 +132,27 @@ struct ProofClients { } impl ProofClients { - fn new(config: &ProverConfig) -> Self { - Self { + fn new(config: &ProverConfig) -> crate::Result { + Ok(Self { append_client: ProofClient::with_config( config.append_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), + )?, nullify_client: ProofClient::with_config( config.update_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), + )?, address_append_client: ProofClient::with_config( config.address_append_url.clone(), config.polling_interval, config.max_wait_time, config.api_key.clone(), - ), - } + )?, + }) } fn get_client(&self, input: &ProofInput) -> &ProofClient { @@ -164,11 +164,11 @@ impl ProofClients { } } -pub fn spawn_proof_workers(config: &ProverConfig) -> async_channel::Sender { +pub fn spawn_proof_workers(config: &ProverConfig) -> crate::Result> { let (job_tx, job_rx) = async_channel::bounded::(256); - let clients = Arc::new(ProofClients::new(config)); + let clients = Arc::new(ProofClients::new(config)?); tokio::spawn(async move { run_proof_pipeline(job_rx, clients).await }); - job_tx + Ok(job_tx) } async fn run_proof_pipeline( diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs index 51ab05143a..51236c389b 100644 --- a/forester/src/processor/v2/strategy/address.rs +++ b/forester/src/processor/v2/strategy/address.rs @@ -167,7 +167,7 @@ impl TreeStrategy for AddressTreeStrategy { } let initial_root = streaming_queue.initial_root(); - let start_index = streaming_queue.start_index(); + let start_index = streaming_queue.tree_next_insertion_index(); let subtrees_arr: [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize] = subtrees.try_into().map_err(|v: Vec<[u8; 32]>| { diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index 097ae1f9a9..6c9fdb5d5e 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -764,7 +764,8 @@ where // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = address_queue.start_index as usize; + let start_index = + address_queue.tree_next_insertion_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 0101b235aa..8ea51b7169 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -132,7 +132,8 @@ impl MockBatchedForester { assert_eq!(computed_new_root, self.merkle_tree.root()); - let proof_result = match ProofClient::local() + let proof_client = ProofClient::local()?; + let proof_result = match proof_client .generate_batch_append_proof(circuit_inputs) .await { @@ -207,7 +208,8 @@ impl MockBatchedForester { batch_size, &[], )?; - let proof_result = ProofClient::local() + let proof_client = ProofClient::local()?; + let proof_result = proof_client .generate_batch_update_proof(inputs) .await?; let new_root = self.merkle_tree.root(); @@ -318,7 +320,8 @@ impl MockBatchedAddressForester { ))); } }; - let proof_result = match ProofClient::local() + let proof_client = ProofClient::local()?; + let proof_result = match proof_client .generate_batch_address_append_proof(inputs) .await { diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index a28efa9ca3..8cec32757f 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -164,7 +164,7 @@ pub async fn create_append_batch_ix_data( bigint_to_be_bytes_array::<32>(&circuit_inputs.new_root.to_biguint().unwrap()).unwrap(), bundle.merkle_tree.root() ); - let proof_client = ProofClient::local(); + let proof_client = ProofClient::local().unwrap(); let inputs_json = BatchAppendInputsJson::from_inputs(&circuit_inputs).to_string(); match proof_client.generate_proof(inputs_json).await { @@ -293,7 +293,7 @@ pub async fn get_batched_nullify_ix_data( &[], ) .unwrap(); - let proof_client = ProofClient::local(); + let proof_client = ProofClient::local().unwrap(); let circuit_inputs_new_root = bigint_to_be_bytes_array::<32>(&inputs.new_root.to_biguint().unwrap()).unwrap(); let inputs_json = update_inputs_string(&inputs); @@ -670,7 +670,7 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof= 1, "start index should be greater than 2 else tree is not inited" @@ -715,7 +715,7 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof(&inputs.new_root).unwrap(); let inputs_json = to_json(&inputs); diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 98457479e2..9a20b8958e 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -49,9 +49,9 @@ pub fn bigint_to_u8_32(n: &BigInt) -> Result<[u8; 32], Box( leaf: [u8; 32], path_elements: &[[u8; 32]; HEIGHT], - path_index: u32, + path_index: usize, ) -> Result<([u8; 32], ChangelogEntry), ProverClientError> { - let mut changelog_entry = ChangelogEntry::default_with_index(path_index as usize); + let mut changelog_entry = ChangelogEntry::default_with_index(path_index); let mut current_hash = leaf; let mut current_index = path_index; diff --git a/prover/client/src/proof_client.rs b/prover/client/src/proof_client.rs index 1d557407bd..820c9fe07d 100644 --- a/prover/client/src/proof_client.rs +++ b/prover/client/src/proof_client.rs @@ -6,7 +6,7 @@ use tokio::time::sleep; use tracing::{debug, error, info, trace, warn}; use crate::{ - constants::PROVE_PATH, + constants::{PROVE_PATH, SERVER_ADDRESS}, errors::ProverClientError, proof::{ compress_proof, deserialize_gnark_proof_json, proof_from_json_struct, ProofCompressed, @@ -17,14 +17,13 @@ use crate::{ batch_append::{BatchAppendInputsJson, BatchAppendsCircuitInputs}, batch_update::{update_inputs_string, BatchUpdateCircuitInputs}, }, + prover::build_http_client, }; const MAX_RETRIES: u32 = 10; const BASE_RETRY_DELAY_SECS: u64 = 1; const DEFAULT_POLLING_INTERVAL_MS: u64 = 100; const DEFAULT_MAX_WAIT_TIME_SECS: u64 = 600; -const DEFAULT_LOCAL_SERVER: &str = "http://localhost:3001"; - const INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS: u64 = 200; const INITIAL_POLL_DELAY_LARGE_CIRCUIT_MS: u64 = 200; @@ -68,15 +67,15 @@ pub struct ProofClient { } impl ProofClient { - pub fn local() -> Self { - Self { - client: Client::new(), - server_address: DEFAULT_LOCAL_SERVER.to_string(), + pub fn local() -> Result { + Ok(Self { + client: build_http_client()?, + server_address: SERVER_ADDRESS.to_string(), polling_interval: Duration::from_millis(DEFAULT_POLLING_INTERVAL_MS), max_wait_time: Duration::from_secs(DEFAULT_MAX_WAIT_TIME_SECS), api_key: None, initial_poll_delay: Duration::from_millis(INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS), - } + }) } #[allow(unused)] @@ -85,21 +84,21 @@ impl ProofClient { polling_interval: Duration, max_wait_time: Duration, api_key: Option, - ) -> Self { + ) -> Result { let initial_poll_delay = if api_key.is_some() { Duration::from_millis(INITIAL_POLL_DELAY_LARGE_CIRCUIT_MS) } else { Duration::from_millis(INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS) }; - Self { - client: Client::new(), + Ok(Self { + client: build_http_client()?, server_address, polling_interval, max_wait_time, api_key, initial_poll_delay, - } + }) } #[allow(unused)] @@ -109,15 +108,15 @@ impl ProofClient { max_wait_time: Duration, api_key: Option, initial_poll_delay: Duration, - ) -> Self { - Self { - client: Client::new(), + ) -> Result { + Ok(Self { + client: build_http_client()?, server_address, polling_interval, max_wait_time, api_key, initial_poll_delay, - } + }) } pub async fn submit_proof_async( diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index 32408fdc02..231fa17ac5 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -199,14 +199,31 @@ pub fn get_batch_address_append_circuit_inputs( changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - let new_element_values = &new_element_values[..zkp_batch_size]; + let batch_len = zkp_batch_size; + for (name, len) in [ + ("new_element_values", new_element_values.len()), + ("low_element_values", low_element_values.len()), + ("low_element_next_values", low_element_next_values.len()), + ("low_element_indices", low_element_indices.len()), + ("low_element_next_indices", low_element_next_indices.len()), + ("low_element_proofs", low_element_proofs.len()), + ] { + if len < batch_len { + return Err(ProverClientError::GenericError(format!( + "truncated batch from indexer: {} len {} < required batch size {}", + name, len, batch_len + ))); + } + } + + let new_element_values = &new_element_values[..batch_len]; let mut new_root = [0u8; 32]; - let mut low_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); - let mut new_element_circuit_merkle_proofs = Vec::with_capacity(new_element_values.len()); - let mut patched_low_element_next_values = Vec::with_capacity(new_element_values.len()); - let mut patched_low_element_next_indices = Vec::with_capacity(new_element_values.len()); - let mut patched_low_element_values = Vec::with_capacity(new_element_values.len()); - let mut patched_low_element_indices = Vec::with_capacity(new_element_values.len()); + let mut low_element_circuit_merkle_proofs = Vec::with_capacity(batch_len); + let mut new_element_circuit_merkle_proofs = Vec::with_capacity(batch_len); + let mut patched_low_element_next_values = Vec::with_capacity(batch_len); + let mut patched_low_element_next_indices = Vec::with_capacity(batch_len); + let mut patched_low_element_values = Vec::with_capacity(batch_len); + let mut patched_low_element_indices = Vec::with_capacity(batch_len); let computed_hashchain = create_hash_chain_from_slice(new_element_values).map_err(|e| { ProverClientError::GenericError(format!("Failed to compute hashchain: {}", e)) @@ -241,7 +258,7 @@ pub fn get_batch_address_append_circuit_inputs( let is_first_batch = indexed_changelog.is_empty(); let mut expected_root_for_low = current_root; - for i in 0..new_element_values.len() { + for i in 0..batch_len { let mut changelog_index = 0; let low_element_index = low_element_indices[i].try_into().map_err(|_| { ProverClientError::IntegerConversion(format!( @@ -322,7 +339,7 @@ pub fn get_batch_address_append_circuit_inputs( let (computed_root, _) = compute_root_from_merkle_proof::( old_low_leaf_hash, &merkle_proof, - low_element.index as u32, + low_element.index, )?; if computed_root != expected_root_for_low { let low_value_bytes = bigint_to_be_bytes_array::<32>(&low_element.value) @@ -363,7 +380,7 @@ pub fn get_batch_address_append_circuit_inputs( compute_root_from_merkle_proof::( new_low_leaf_hash, &merkle_proof, - new_low_element.index as u32, + new_low_element.index, )?; patcher.push_changelog_entry::(changelog, changelog_entry); @@ -404,7 +421,7 @@ pub fn get_batch_address_append_circuit_inputs( let (updated_root, changelog_entry) = compute_root_from_merkle_proof( new_element_leaf_hash, &merkle_proof_array, - current_index as u32, + current_index, )?; if i == 0 && changelog.len() == 1 { @@ -431,7 +448,7 @@ pub fn get_batch_address_append_circuit_inputs( let (root_with_zero, _) = compute_root_from_merkle_proof::( zero_hash, &merkle_proof_array, - current_index as u32, + current_index, )?; if root_with_zero != intermediate_root { tracing::error!( diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index 7dd578e599..41a6dcfcd6 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -190,7 +190,7 @@ pub fn get_batch_append_inputs( let (updated_root, changelog_entry) = compute_root_from_merkle_proof( final_leaf, &merkle_proof_array, - start_index + i as u32, + start_index as usize + i, )?; new_root = updated_root; changelog.push(changelog_entry); diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 7f8c08e0d1..2ada02b92b 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -175,7 +175,7 @@ pub fn get_batch_update_inputs( index_bytes[28..].copy_from_slice(&(*index).to_be_bytes()); let nullifier = Poseidon::hashv(&[leaf, &index_bytes, &tx_hashes[i]]).unwrap(); let (root, changelog_entry) = - compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index)?; + compute_root_from_merkle_proof(nullifier, &merkle_proof_array, *index as usize)?; new_root = root; changelog.push(changelog_entry); circuit_merkle_proofs.push( diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 56ae20d98a..5b5cfacc77 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -1,19 +1,120 @@ use std::{ - process::Command, + io::{Read, Write}, + net::{TcpStream, ToSocketAddrs}, + process::{Command, Stdio}, sync::atomic::{AtomicBool, Ordering}, - thread::sleep, time::Duration, }; use tracing::info; +use tokio::time::sleep; use crate::{ constants::{HEALTH_CHECK, SERVER_ADDRESS}, + errors::ProverClientError, helpers::get_project_root, }; static IS_LOADING: AtomicBool = AtomicBool::new(false); +pub(crate) fn build_http_client() -> Result { + reqwest::Client::builder() + .no_proxy() + .build() + .map_err(|error| { + ProverClientError::GenericError(format!("failed to build HTTP client: {error}")) + }) +} + +fn health_check_once(timeout: Duration) -> bool { + if prover_listener_present() { + return true; + } + + let endpoint = SERVER_ADDRESS + .strip_prefix("http://") + .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) + .unwrap_or(SERVER_ADDRESS); + let addr = match endpoint.to_socket_addrs().ok().and_then(|mut addrs| addrs.next()) { + Some(addr) => addr, + None => return false, + }; + + let mut stream = match TcpStream::connect_timeout(&addr, timeout) { + Ok(stream) => stream, + Err(error) => { + tracing::debug!(?error, endpoint, "prover health TCP connect failed"); + return health_check_once_with_curl(timeout); + } + }; + + let _ = stream.set_read_timeout(Some(timeout)); + let _ = stream.set_write_timeout(Some(timeout)); + + let host = endpoint.split(':').next().unwrap_or("127.0.0.1"); + let request = + format!("GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", HEALTH_CHECK, host); + if let Err(error) = stream.write_all(request.as_bytes()) { + tracing::debug!(?error, "failed to write prover health request"); + return health_check_once_with_curl(timeout); + } + + let mut response = [0u8; 512]; + let bytes_read = match stream.read(&mut response) { + Ok(bytes_read) => bytes_read, + Err(error) => { + tracing::debug!(?error, "failed to read prover health response"); + return health_check_once_with_curl(timeout); + } + }; + + if bytes_read == 0 { + return false; + } + + let response = std::str::from_utf8(&response[..bytes_read]).unwrap_or_default(); + response.contains("200 OK") + || response.contains("{\"status\":\"ok\"}") + || health_check_once_with_curl(timeout) +} + +fn prover_listener_present() -> bool { + let endpoint = SERVER_ADDRESS + .strip_prefix("http://") + .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) + .unwrap_or(SERVER_ADDRESS); + let port = endpoint.rsplit(':').next().unwrap_or("3001"); + + match Command::new("lsof") + .args(["-nP", &format!("-iTCP:{port}"), "-sTCP:LISTEN"]) + .output() + { + Ok(output) => output.status.success() && !output.stdout.is_empty(), + Err(error) => { + tracing::debug!(?error, "failed to execute lsof prover listener check"); + false + } + } +} + +fn health_check_once_with_curl(timeout: Duration) -> bool { + let timeout_secs = timeout.as_secs().max(1).to_string(); + let url = format!("{}{}", SERVER_ADDRESS, HEALTH_CHECK); + match Command::new("curl") + .args(["-sS", "-m", timeout_secs.as_str(), url.as_str()]) + .output() + { + Ok(output) => { + output.status.success() + && String::from_utf8_lossy(&output.stdout).contains("{\"status\":\"ok\"}") + } + Err(error) => { + tracing::debug!(?error, "failed to execute curl prover health check"); + false + } + } +} + pub async fn spawn_prover() { if let Some(_project_root) = get_project_root() { let prover_path: &str = { @@ -28,48 +129,81 @@ pub async fn spawn_prover() { } }; - if !health_check(10, 1).await && !IS_LOADING.load(Ordering::Relaxed) { - IS_LOADING.store(true, Ordering::Relaxed); + if health_check(10, 1).await { + return; + } + + if IS_LOADING + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_err() + { + if health_check(120, 1).await { + return; + } + panic!("Failed to start prover, health check failed."); + } + + let spawn_result = async { + let mut command = Command::new(prover_path); + command.arg("start-prover").stdout(Stdio::piped()).stderr(Stdio::piped()); + let mut child = command.spawn().expect("Failed to start prover process"); + let mut child_exit_status = None; - let command = Command::new(prover_path) - .arg("start-prover") - .spawn() - .expect("Failed to start prover process"); + for _ in 0..120 { + if health_check(1, 1).await { + info!("Prover started successfully"); + return; + } - let _ = command.wait_with_output(); + if child_exit_status.is_none() { + match child.try_wait() { + Ok(Some(status)) => { + tracing::warn!( + ?status, + "prover launcher exited before health check succeeded; continuing to poll for detached prover" + ); + child_exit_status = Some(status); + } + Ok(None) => {} + Err(error) => { + tracing::error!(?error, "failed to poll prover child process"); + } + } + } - let health_result = health_check(120, 1).await; - if health_result { - info!("Prover started successfully"); - } else { - panic!("Failed to start prover, health check failed."); + sleep(Duration::from_secs(1)).await; } + + if let Some(status) = child_exit_status { + panic!( + "Failed to start prover, health check failed after launcher exited with status {status}." + ); + } + + panic!("Failed to start prover, health check failed."); } + .await; + + IS_LOADING.store(false, Ordering::Release); + spawn_result } else { panic!("Failed to find project root."); }; } pub async fn health_check(retries: usize, timeout: usize) -> bool { - let client = match reqwest::Client::builder().no_proxy().build() { - Ok(client) => client, - Err(_) => return false, - }; - let mut result = false; - for _ in 0..retries { - match client - .get(format!("{}{}", SERVER_ADDRESS, HEALTH_CHECK)) - .send() - .await - { - Ok(_) => { - result = true; - break; - } - Err(_) => { - sleep(Duration::from_secs(timeout as u64)); - } + let timeout = Duration::from_secs(timeout as u64); + let retry_delay = timeout; + + for attempt in 0..retries { + if health_check_once(timeout) { + return true; + } + + if attempt + 1 < retries { + sleep(retry_delay).await; } } - result + + false } diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index ac73c3809e..8a02c363ff 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -26,53 +26,55 @@ async fn prove_batch_address_append() { spawn_prover().await; // Initialize test data - let mut new_element_values = vec![]; - let zkp_batch_size = 10; - for i in 1..zkp_batch_size + 1 { - new_element_values.push(num_bigint::ToBigUint::to_biguint(&i).unwrap()); - } + let total_batch_size = 10usize; + let warmup_batch_size = 1usize; + let prior_value = 999_u32.to_biguint().unwrap(); + let new_element_values = (1..=total_batch_size) + .map(|i| num_bigint::ToBigUint::to_biguint(&i).unwrap()) + .collect::>(); // Initialize indexing structures - let relayer_merkle_tree = + let mut relayer_merkle_tree = IndexedMerkleTree::::new(DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize, 0) .unwrap(); - let start_index = relayer_merkle_tree.merkle_tree.rightmost_index; - let current_root = relayer_merkle_tree.root(); + let collect_non_inclusion_data = + |tree: &IndexedMerkleTree, values: &[BigUint]| { + let mut low_element_values = Vec::with_capacity(values.len()); + let mut low_element_indices = Vec::with_capacity(values.len()); + let mut low_element_next_indices = Vec::with_capacity(values.len()); + let mut low_element_next_values = Vec::with_capacity(values.len()); + let mut low_element_proofs: Vec< + [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize], + > = Vec::with_capacity(values.len()); - // Prepare proof components - let mut low_element_values = Vec::new(); - let mut low_element_indices = Vec::new(); - let mut low_element_next_indices = Vec::new(); - let mut low_element_next_values = Vec::new(); - let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = - Vec::new(); + for new_element_value in values { + let non_inclusion_proof = tree.get_non_inclusion_proof(new_element_value).unwrap(); - // Generate non-inclusion proofs for each element - for new_element_value in &new_element_values { - let non_inclusion_proof = relayer_merkle_tree - .get_non_inclusion_proof(new_element_value) - .unwrap(); + low_element_values.push(non_inclusion_proof.leaf_lower_range_value); + low_element_indices.push(non_inclusion_proof.leaf_index); + low_element_next_indices.push(non_inclusion_proof.next_index); + low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); + } - low_element_values.push(non_inclusion_proof.leaf_lower_range_value); - low_element_indices.push(non_inclusion_proof.leaf_index); - low_element_next_indices.push(non_inclusion_proof.next_index); - low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push( - non_inclusion_proof - .merkle_proof - .as_slice() - .try_into() - .unwrap(), - ); - } + ( + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + low_element_proofs, + ) + }; - // Convert big integers to byte arrays - let new_element_values = new_element_values - .iter() - .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) - .collect::>(); - let hash_chain = create_hash_chain_from_slice(&new_element_values).unwrap(); + let initial_start_index = relayer_merkle_tree.merkle_tree.rightmost_index; + let initial_root = relayer_merkle_tree.root(); let subtrees: [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize] = relayer_merkle_tree .merkle_tree @@ -82,7 +84,7 @@ async fn prove_batch_address_append() { let mut sparse_merkle_tree = SparseMerkleTree::< Poseidon, { DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }, - >::new(subtrees, start_index); + >::new(subtrees, initial_start_index); let mut changelog: Vec> = Vec::new(); @@ -90,6 +92,55 @@ async fn prove_batch_address_append() { IndexedChangelogEntry, > = Vec::new(); + let warmup_values = vec![prior_value.clone()]; + let ( + warmup_low_element_values, + warmup_low_element_indices, + warmup_low_element_next_indices, + warmup_low_element_next_values, + warmup_low_element_proofs, + ) = collect_non_inclusion_data(&relayer_merkle_tree, &warmup_values); + let warmup_values = warmup_values + .iter() + .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) + .collect::>(); + let warmup_hash_chain = create_hash_chain_from_slice(&warmup_values).unwrap(); + + get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( + initial_start_index, + initial_root, + &warmup_low_element_values, + &warmup_low_element_next_values, + &warmup_low_element_indices, + &warmup_low_element_next_indices, + &warmup_low_element_proofs, + &warmup_values, + &mut sparse_merkle_tree, + warmup_hash_chain, + warmup_batch_size, + &mut changelog, + &mut indexed_changelog, + ) + .unwrap(); + + relayer_merkle_tree.append(&prior_value).unwrap(); + + let remaining_values = &new_element_values[..]; + let ( + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + low_element_proofs, + ) = collect_non_inclusion_data(&relayer_merkle_tree, remaining_values); + let new_element_values = remaining_values + .iter() + .map(|v| bigint_to_be_bytes_array::<32>(v).unwrap()) + .collect::>(); + let hash_chain = create_hash_chain_from_slice(&new_element_values).unwrap(); + let start_index = relayer_merkle_tree.merkle_tree.rightmost_index; + let current_root = relayer_merkle_tree.root(); + let inputs = get_batch_address_append_circuit_inputs::<{ DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize }>( start_index, @@ -102,7 +153,7 @@ async fn prove_batch_address_append() { &new_element_values, &mut sparse_merkle_tree, hash_chain, - zkp_batch_size, + total_batch_size, &mut changelog, &mut indexed_changelog, ) diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index 5698719c8f..eb8890d6b7 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, time::Duration}; use async_trait::async_trait; use bs58; -use light_sdk_types::constants::STATE_MERKLE_TREE_CANOPY_DEPTH; +use light_sdk_types::constants::{STATE_MERKLE_TREE_CANOPY_DEPTH, STATE_MERKLE_TREE_HEIGHT}; use photon_api::apis::configuration::Configuration; use solana_pubkey::Pubkey; use tracing::{error, trace, warn}; @@ -1142,14 +1142,24 @@ impl Indexer for PhotonIndexer { .value .iter() .map(|x| { - if x.proof.len() < STATE_MERKLE_TREE_CANOPY_DEPTH { + let expected_siblings = + STATE_MERKLE_TREE_HEIGHT - STATE_MERKLE_TREE_CANOPY_DEPTH; + let expected_total = STATE_MERKLE_TREE_CANOPY_DEPTH + expected_siblings; + if x.proof.len() != expected_total { return Err(IndexerError::InvalidParameters(format!( - "Merkle proof length ({}) is less than canopy depth ({})", + "Merkle proof length ({}) does not match expected total proof length ({})", x.proof.len(), - STATE_MERKLE_TREE_CANOPY_DEPTH, + expected_total, ))); } let proof_len = x.proof.len() - STATE_MERKLE_TREE_CANOPY_DEPTH; + if proof_len != expected_siblings { + return Err(IndexerError::InvalidParameters(format!( + "Merkle proof sibling count ({}) does not match expected sibling count ({})", + proof_len, + expected_siblings, + ))); + } let proof = x.proof[..proof_len] .iter() @@ -1681,6 +1691,7 @@ impl Indexer for PhotonIndexer { .map(|h| super::base58::decode_base58_to_fixed_array(&h.0)) .collect::, _>>()?, start_index: aq.start_index, + tree_next_insertion_index: aq.start_index, root_seq: aq.root_seq, }) } else { diff --git a/sdk-libs/client/src/indexer/types/queue.rs b/sdk-libs/client/src/indexer/types/queue.rs index 3f52d72798..79fcb45cf1 100644 --- a/sdk-libs/client/src/indexer/types/queue.rs +++ b/sdk-libs/client/src/indexer/types/queue.rs @@ -61,7 +61,10 @@ pub struct AddressQueueData { pub initial_root: [u8; 32], pub leaves_hash_chains: Vec<[u8; 32]>, pub subtrees: Vec<[u8; 32]>, + /// Pagination offset for the returned queue slice. pub start_index: u64, + /// Sparse tree insertion point / next index used to initialize staging trees. + pub tree_next_insertion_index: u64, pub root_seq: u64, } @@ -130,6 +133,19 @@ impl AddressQueueData { address_range: std::ops::Range, ) -> Result, IndexerError> { self.validate_proof_height::()?; + let available = self.proof_count(); + if address_range.start > address_range.end { + return Err(IndexerError::InvalidParameters(format!( + "invalid address proof range {}..{}", + address_range.start, address_range.end + ))); + } + if address_range.end > available { + return Err(IndexerError::InvalidParameters(format!( + "address proof range {}..{} exceeds available proofs {}", + address_range.start, address_range.end, available + ))); + } let node_lookup = self.build_node_lookup(); let mut proofs = Vec::with_capacity(address_range.len()); @@ -156,6 +172,10 @@ impl AddressQueueData { lookup } + fn proof_count(&self) -> usize { + self.addresses.len().min(self.low_element_indices.len()) + } + fn reconstruct_proof_with_lookup( &self, address_idx: usize, @@ -279,6 +299,7 @@ mod tests { leaves_hash_chains: vec![[3u8; 32]; num_addresses.max(1)], subtrees: vec![[4u8; 32]; HEIGHT], start_index: 0, + tree_next_insertion_index: 0, root_seq: 0, } } diff --git a/sdk-libs/client/src/interface/load_accounts.rs b/sdk-libs/client/src/interface/load_accounts.rs index c70088dd40..0d564734a2 100644 --- a/sdk-libs/client/src/interface/load_accounts.rs +++ b/sdk-libs/client/src/interface/load_accounts.rs @@ -1,5 +1,6 @@ //! Load cold accounts API. +use futures::{stream, StreamExt, TryStreamExt}; use light_account::{derive_rent_sponsor_pda, Pack}; use light_compressed_account::{ compressed_account::PackedMerkleContext, instruction_data::compressed_proof::ValidityProof, @@ -71,6 +72,7 @@ pub enum LoadAccountsError { const MAX_ATAS_PER_IX: usize = 8; const MAX_PDAS_PER_IX: usize = 8; +const PROOF_FETCH_CONCURRENCY: usize = 8; /// Build load instructions for cold accounts. Returns empty vec if all hot. /// @@ -118,9 +120,14 @@ where .collect(); let pda_groups = group_pda_specs(&cold_pdas, MAX_PDAS_PER_IX); + let mut pda_offset = 0usize; let pda_hashes = pda_groups .iter() - .map(|group| collect_pda_hashes(group)) + .map(|group| { + let hashes = collect_pda_hashes(group, pda_offset)?; + pda_offset += group.len(); + Ok::<_, LoadAccountsError>(hashes) + }) .collect::, _>>()?; let ata_hashes = collect_ata_hashes(&cold_atas)?; let mint_hashes = collect_mint_hashes(&cold_mints)?; @@ -161,13 +168,16 @@ where Ok(out) } -fn collect_pda_hashes(specs: &[&PdaSpec]) -> Result, LoadAccountsError> { +fn collect_pda_hashes( + specs: &[&PdaSpec], + start_index: usize, +) -> Result, LoadAccountsError> { specs .iter() .enumerate() .map(|(i, s)| { s.hash().ok_or(LoadAccountsError::MissingPdaCompressed { - index: i, + index: start_index + i, pubkey: s.address(), }) }) @@ -249,13 +259,16 @@ async fn fetch_individual_proofs( return Ok(vec![]); } - futures::future::try_join_all(hashes.iter().map(|hash| async move { - indexer - .get_validity_proof(vec![*hash], vec![], None) - .await - .map(|response| response.value) - })) - .await + stream::iter(hashes.iter().copied()) + .map(|hash| async move { + indexer + .get_validity_proof(vec![hash], vec![], None) + .await + .map(|response| response.value) + }) + .buffered(PROOF_FETCH_CONCURRENCY) + .try_collect() + .await } async fn fetch_proof_batches( @@ -266,13 +279,16 @@ async fn fetch_proof_batches( return Ok(vec![]); } - futures::future::try_join_all(hash_batches.iter().map(|hashes| async move { - indexer - .get_validity_proof(hashes.clone(), vec![], None) - .await - .map(|response| response.value) - })) - .await + stream::iter(hash_batches.iter().cloned()) + .map(|hashes| async move { + indexer + .get_validity_proof(hashes, vec![], None) + .await + .map(|response| response.value) + }) + .buffered(PROOF_FETCH_CONCURRENCY) + .try_collect() + .await } async fn fetch_proofs_batched( diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index c51298b9cd..e799d3f29e 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -726,9 +726,8 @@ impl Indexer for TestIndexer { initial_root: address_tree_bundle.root(), leaves_hash_chains: Vec::new(), subtrees: address_tree_bundle.get_subtrees(), - // Consumers use start_index as the sparse tree's next insertion index, - // not the pagination offset used for queue slicing. - start_index: address_tree_bundle.right_most_index() as u64, + start_index: start as u64, + tree_next_insertion_index: address_tree_bundle.right_most_index() as u64, root_seq: address_tree_bundle.sequence_number(), }) } else { diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs index 3e32ec6ef3..8819724f59 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs @@ -143,12 +143,21 @@ where let address = derive_address(&pda_key, &ctx.light_config.address_space[0], ctx.program_id); // 10. Build CompressedAccountInfo for CPI + // When PDA decompression is only the first phase of a later token Transfer2 execution, + // the stored input queue index must match that later execution basis, not the original + // packed proof basis. The mixed PDA+token flow uses `output_queue_index` for that. + let input_queue_index = if ctx.cpi_accounts.config().cpi_context { + output_queue_index + } else { + tree_info.queue_pubkey_index + }; + let input = InAccountInfo { data_hash: input_data_hash, lamports: 0, merkle_context: PackedMerkleContext { merkle_tree_pubkey_index: tree_info.merkle_tree_pubkey_index, - queue_pubkey_index: tree_info.queue_pubkey_index, + queue_pubkey_index: input_queue_index, leaf_index: tree_info.leaf_index, prove_by_index: tree_info.prove_by_index, }, diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs b/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs index 2fc7cfc811..2fc1f037b8 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/processor.rs @@ -126,7 +126,7 @@ pub struct DecompressCtx<'a, AI: AccountInfoTrait + Clone> { #[cfg(feature = "token")] pub in_tlv: Option>>, #[cfg(feature = "token")] - pub token_seeds: Vec>, + pub token_seeds: Vec>>, } // ============================================================================ @@ -296,7 +296,7 @@ pub struct DecompressAccountsBuilt<'a, AI: AccountInfoTrait + Clone> { pub cpi_context: bool, pub in_token_data: Vec, pub in_tlv: Option>>, - pub token_seeds: Vec>, + pub token_seeds: Vec>>, } /// Validates accounts, dispatches all variants, and collects CPI inputs for @@ -649,13 +649,20 @@ where .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; } else { // At least one regular token account - use invoke_signed with PDA seeds - let signer_seed_refs: Vec<&[u8]> = token_seeds.iter().map(|s| s.as_slice()).collect(); + let signer_seed_storage: Vec> = token_seeds + .iter() + .map(|seed_group| seed_group.iter().map(|seed| seed.as_slice()).collect()) + .collect(); + let signer_seed_refs: Vec<&[&[u8]]> = signer_seed_storage + .iter() + .map(|seed_group| seed_group.as_slice()) + .collect(); AI::invoke_cpi( &LIGHT_TOKEN_PROGRAM_ID, &transfer2_data, &account_metas, remaining_accounts, - &[signer_seed_refs.as_slice()], + signer_seed_refs.as_slice(), ) .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; } diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/token.rs b/sdk-libs/sdk-types/src/interface/program/decompression/token.rs index 153943e275..2b6fa37a7a 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/token.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/token.rs @@ -148,8 +148,9 @@ where ) .map_err(|e| LightSdkTypesError::ProgramError(e.into()))?; - // Push seeds for the Transfer2 CPI (needed for invoke_signed) - ctx.token_seeds.extend(seeds.iter().map(|s| s.to_vec())); + // Push one signer seed group per vault PDA for the later Transfer2 CPI. + ctx.token_seeds + .push(seeds.iter().map(|seed| seed.to_vec()).collect()); } // Push token data for the Transfer2 CPI (common for both ATA and regular paths) From 7ab1c909486b75162eca7d29c7445d611e50a9a5 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Mon, 16 Mar 2026 14:50:22 +0000 Subject: [PATCH 6/9] Fix prover startup and decompression load flow --- prover/client/src/errors.rs | 1 - prover/client/src/prover.rs | 81 +++----- sdk-libs/client/src/interface/instructions.rs | 10 +- .../program-test/src/indexer/test_indexer.rs | 191 ++++++++++++++---- .../src/interface/account/token_seeds.rs | 4 +- .../interface/program/decompression/pda.rs | 16 +- .../tests/basic_test.rs | 9 +- 7 files changed, 200 insertions(+), 112 deletions(-) diff --git a/prover/client/src/errors.rs b/prover/client/src/errors.rs index e095bf3579..859cae32b8 100644 --- a/prover/client/src/errors.rs +++ b/prover/client/src/errors.rs @@ -39,7 +39,6 @@ pub enum ProverClientError { #[error("Integer conversion failed: {0}")] IntegerConversion(String), - #[error("Hashchain mismatch: computed {computed:?} != expected {expected:?} (batch_size={batch_size}, next_index={next_index})")] HashchainMismatch { computed: [u8; 32], diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index 5b5cfacc77..e0f9d4060d 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -1,13 +1,13 @@ use std::{ io::{Read, Write}, net::{TcpStream, ToSocketAddrs}, - process::{Command, Stdio}, + process::Command, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; -use tracing::info; use tokio::time::sleep; +use tracing::info; use crate::{ constants::{HEALTH_CHECK, SERVER_ADDRESS}, @@ -16,6 +16,18 @@ use crate::{ }; static IS_LOADING: AtomicBool = AtomicBool::new(false); +const STARTUP_HEALTH_CHECK_RETRIES: usize = 300; + +fn has_http_ok_status(response: &[u8]) -> bool { + response + .split(|&byte| byte == b'\n') + .next() + .map(|status_line| { + status_line.starts_with(b"HTTP/") + && status_line.windows(5).any(|window| window == b" 200 ") + }) + .unwrap_or(false) +} pub(crate) fn build_http_client() -> Result { reqwest::Client::builder() @@ -59,7 +71,7 @@ fn health_check_once(timeout: Duration) -> bool { return health_check_once_with_curl(timeout); } - let mut response = [0u8; 512]; + let mut response = [0_u8; 512]; let bytes_read = match stream.read(&mut response) { Ok(bytes_read) => bytes_read, Err(error) => { @@ -68,14 +80,8 @@ fn health_check_once(timeout: Duration) -> bool { } }; - if bytes_read == 0 { - return false; - } - - let response = std::str::from_utf8(&response[..bytes_read]).unwrap_or_default(); - response.contains("200 OK") - || response.contains("{\"status\":\"ok\"}") - || health_check_once_with_curl(timeout) + bytes_read > 0 + && (has_http_ok_status(&response[..bytes_read]) || health_check_once_with_curl(timeout)) } fn prover_listener_present() -> bool { @@ -117,15 +123,15 @@ fn health_check_once_with_curl(timeout: Duration) -> bool { pub async fn spawn_prover() { if let Some(_project_root) = get_project_root() { - let prover_path: &str = { + let prover_path = { #[cfg(feature = "devenv")] { - &format!("{}/{}", _project_root.trim(), "cli/test_bin/run") + format!("{}/{}", _project_root.trim(), "cli/test_bin/run") } #[cfg(not(feature = "devenv"))] { println!("Running in production mode, using prover binary"); - "light" + "light".to_string() } }; @@ -137,50 +143,23 @@ pub async fn spawn_prover() { .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) .is_err() { - if health_check(120, 1).await { + if health_check(STARTUP_HEALTH_CHECK_RETRIES, 1).await { return; } panic!("Failed to start prover, health check failed."); } let spawn_result = async { - let mut command = Command::new(prover_path); - command.arg("start-prover").stdout(Stdio::piped()).stderr(Stdio::piped()); - let mut child = command.spawn().expect("Failed to start prover process"); - let mut child_exit_status = None; - - for _ in 0..120 { - if health_check(1, 1).await { - info!("Prover started successfully"); - return; - } - - if child_exit_status.is_none() { - match child.try_wait() { - Ok(Some(status)) => { - tracing::warn!( - ?status, - "prover launcher exited before health check succeeded; continuing to poll for detached prover" - ); - child_exit_status = Some(status); - } - Ok(None) => {} - Err(error) => { - tracing::error!(?error, "failed to poll prover child process"); - } - } - } - - sleep(Duration::from_secs(1)).await; - } - - if let Some(status) = child_exit_status { - panic!( - "Failed to start prover, health check failed after launcher exited with status {status}." - ); + Command::new(&prover_path) + .arg("start-prover") + .spawn() + .unwrap_or_else(|error| panic!("Failed to start prover process: {error}")); + + if health_check(STARTUP_HEALTH_CHECK_RETRIES, 1).await { + info!("Prover started successfully"); + } else { + panic!("Failed to start prover, health check failed."); } - - panic!("Failed to start prover, health check failed."); } .await; diff --git a/sdk-libs/client/src/interface/instructions.rs b/sdk-libs/client/src/interface/instructions.rs index f6d754b9b1..e80e7b72c1 100644 --- a/sdk-libs/client/src/interface/instructions.rs +++ b/sdk-libs/client/src/interface/instructions.rs @@ -8,7 +8,7 @@ use light_account::{ CompressedAccountData, InitializeLightConfigParams, Pack, UpdateLightConfigParams, }; use light_sdk::instruction::{ - account_meta::CompressedAccountMetaNoLamportsNoAddress, PackedAccounts, + account_meta::CompressedAccountMetaNoLamportsNoAddress, PackedAccounts, PackedStateTreeInfo, SystemAccountMetaConfig, ValidityProof, }; use light_token::constants::{ @@ -247,11 +247,15 @@ where // Process PDAs first, then tokens, to match on-chain split_at(token_accounts_offset). for &i in pda_indices.iter().chain(token_indices.iter()) { let (acc, data) = &cold_accounts[i]; - let _queue_index = remaining_accounts.insert_or_get(acc.tree_info.queue); - let tree_info = tree_infos + let proof_tree_info = tree_infos .get(i) .copied() .ok_or("tree info index out of bounds")?; + let queue_index = remaining_accounts.insert_or_get(acc.tree_info.queue); + let tree_info = PackedStateTreeInfo { + queue_pubkey_index: queue_index, + ..proof_tree_info + }; let packed_data = data.pack(&mut remaining_accounts)?; typed_accounts.push(CompressedAccountData { diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index e799d3f29e..6e61c3bd98 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -472,8 +472,6 @@ impl Indexer for TestIndexer { let account_data = account.value.ok_or(IndexerError::AccountNotFound)?; state_merkle_tree_pubkeys.push(account_data.tree_info.tree); } - println!("state_merkle_tree_pubkeys {:?}", state_merkle_tree_pubkeys); - println!("hashes {:?}", hashes); let mut proof_inputs = vec![]; let mut indices_to_remove = Vec::new(); @@ -495,14 +493,7 @@ impl Indexer for TestIndexer { .output_queue_elements .iter() .find(|(hash, _)| hash == compressed_account); - println!("queue_element {:?}", queue_element); - if let Some((_, index)) = queue_element { - println!("index {:?}", index); - println!( - "accounts.output_queue_batch_size {:?}", - accounts.output_queue_batch_size - ); if accounts.output_queue_batch_size.is_some() && accounts.leaf_index_in_queue_range(*index as usize)? { @@ -513,12 +504,7 @@ impl Indexer for TestIndexer { hash: *compressed_account, root: [0u8; 32], root_index: RootIndex::new_none(), - leaf_index: accounts - .output_queue_elements - .iter() - .position(|(x, _)| x == compressed_account) - .unwrap() - as u64, + leaf_index: *index, tree_info: light_client::indexer::TreeInfo { cpi_context: Some(accounts.accounts.cpi_context), tree: accounts.accounts.merkle_tree, @@ -2085,6 +2071,106 @@ impl TestIndexer { } } +#[cfg(all(test, feature = "v2"))] +mod tests { + use super::*; + use light_compressed_account::compressed_account::CompressedAccount; + + fn queued_account( + owner: [u8; 32], + merkle_tree: Pubkey, + queue: Pubkey, + leaf_index: u32, + ) -> CompressedAccountWithMerkleContext { + CompressedAccountWithMerkleContext { + compressed_account: CompressedAccount { + owner: owner.into(), + lamports: 0, + address: None, + data: None, + }, + merkle_context: MerkleContext { + merkle_tree_pubkey: merkle_tree.to_bytes().into(), + queue_pubkey: queue.to_bytes().into(), + leaf_index, + prove_by_index: false, + tree_type: TreeType::StateV2, + }, + } + } + + #[tokio::test] + async fn get_validity_proof_preserves_sparse_queue_leaf_indices() { + let merkle_tree = Pubkey::new_unique(); + let queue = Pubkey::new_unique(); + let sparse_leaf_indices = [5_u32, 1, 0, 4]; + + let compressed_accounts: Vec<_> = sparse_leaf_indices + .iter() + .enumerate() + .map(|(i, &leaf_index)| { + queued_account([i as u8 + 1; 32], merkle_tree, queue, leaf_index) + }) + .collect(); + let hashes: Vec<_> = compressed_accounts + .iter() + .map(|account| account.hash().unwrap()) + .collect(); + + let output_queue_elements = hashes + .iter() + .zip(sparse_leaf_indices.iter()) + .map(|(hash, &leaf_index)| (*hash, leaf_index as u64)) + .collect(); + + let indexer = TestIndexer { + state_merkle_trees: vec![StateMerkleTreeBundle { + rollover_fee: 0, + network_fee: 0, + merkle_tree: Box::new(MerkleTree::::new_with_history( + DEFAULT_BATCH_STATE_TREE_HEIGHT, + 0, + 0, + DEFAULT_BATCH_ROOT_HISTORY_LEN, + )), + accounts: StateMerkleTreeAccounts { + merkle_tree, + nullifier_queue: queue, + cpi_context: Pubkey::new_unique(), + tree_type: TreeType::StateV2, + }, + tree_type: TreeType::StateV2, + output_queue_elements, + input_leaf_indices: vec![], + output_queue_batch_size: Some(500), + num_inserted_batches: 0, + }], + address_merkle_trees: vec![], + payer: Keypair::new(), + governance_authority: Keypair::new(), + group_pda: Pubkey::new_unique(), + compressed_accounts, + nullified_compressed_accounts: vec![], + token_compressed_accounts: vec![], + token_nullified_compressed_accounts: vec![], + events: vec![], + onchain_pubkey_index: HashMap::new(), + }; + + let response = Indexer::get_validity_proof(&indexer, hashes, vec![], None) + .await + .unwrap(); + let leaf_indices: Vec = response + .value + .accounts + .iter() + .map(|account| account.leaf_index) + .collect(); + + assert_eq!(leaf_indices, sparse_leaf_indices.map(u64::from)); + } +} + impl TestIndexer { async fn process_inclusion_proofs( &self, @@ -2346,7 +2432,16 @@ impl TestIndexer { new_addresses.unwrap().len() ))); } - let client = Client::new(); + let client = Client::builder() + .no_proxy() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(120)) + .build() + .map_err(|error| { + IndexerError::CustomError(format!( + "failed to build prover HTTP client: {error}" + )) + })?; let (account_proof_inputs, address_proof_inputs, json_payload) = match (compressed_accounts, new_addresses) { (Some(accounts), None) => { @@ -2471,6 +2566,7 @@ impl TestIndexer { }; let mut retries = 3; + let mut last_error = "Failed to get proof from server".to_string(); while retries > 0 { let response_result = client .post(format!("{}{}", SERVER_ADDRESS, PROVE_PATH)) @@ -2478,33 +2574,50 @@ impl TestIndexer { .body(json_payload.clone()) .send() .await; - if let Ok(response_result) = response_result { - if response_result.status().is_success() { - let body = response_result.text().await.unwrap(); - let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = - compress_proof(&proof_a, &proof_b, &proof_c); - return Ok(ValidityProofWithContext { - accounts: account_proof_inputs, - addresses: address_proof_inputs, - proof: CompressedProof { - a: proof_a, - b: proof_b, - c: proof_c, - } - .into(), - }); + match response_result { + Ok(response_result) => { + let status = response_result.status(); + let body = response_result.text().await.map_err(|error| { + IndexerError::CustomError(format!( + "failed to read prover response body: {error}" + )) + })?; + + if status.is_success() { + let proof_json = deserialize_gnark_proof_json(&body) + .map_err(|error| IndexerError::CustomError(error.to_string()))?; + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); + let (proof_a, proof_b, proof_c) = + compress_proof(&proof_a, &proof_b, &proof_c); + return Ok(ValidityProofWithContext { + accounts: account_proof_inputs, + addresses: address_proof_inputs, + proof: CompressedProof { + a: proof_a, + b: proof_b, + c: proof_c, + } + .into(), + }); + } + + let body_preview: String = body.chars().take(512).collect(); + last_error = format!( + "prover returned HTTP {status} for validity proof request: {body_preview}" + ); } - } else { - println!("Error: {:#?}", response_result); + Err(error) => { + last_error = + format!("failed to contact prover for validity proof: {error}"); + } + } + + retries -= 1; + if retries > 0 { tokio::time::sleep(Duration::from_secs(5)).await; - retries -= 1; } } - Err(IndexerError::CustomError( - "Failed to get proof from server".to_string(), - )) + Err(IndexerError::CustomError(last_error)) } } } diff --git a/sdk-libs/sdk-types/src/interface/account/token_seeds.rs b/sdk-libs/sdk-types/src/interface/account/token_seeds.rs index f22657590a..2bd0ee7bdc 100644 --- a/sdk-libs/sdk-types/src/interface/account/token_seeds.rs +++ b/sdk-libs/sdk-types/src/interface/account/token_seeds.rs @@ -265,7 +265,7 @@ where fn into_in_token_data( &self, tree_info: &PackedStateTreeInfo, - output_queue_index: u8, + _output_queue_index: u8, ) -> Result { Ok(MultiInputTokenDataWithContext { amount: self.token_data.amount, @@ -277,7 +277,7 @@ where root_index: tree_info.root_index, merkle_context: PackedMerkleContext { merkle_tree_pubkey_index: tree_info.merkle_tree_pubkey_index, - queue_pubkey_index: output_queue_index, + queue_pubkey_index: tree_info.queue_pubkey_index, leaf_index: tree_info.leaf_index, prove_by_index: tree_info.prove_by_index, }, diff --git a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs index 8819724f59..cc7aa4ba1f 100644 --- a/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs +++ b/sdk-libs/sdk-types/src/interface/program/decompression/pda.rs @@ -142,22 +142,16 @@ where let pda_key = pda_account.key(); let address = derive_address(&pda_key, &ctx.light_config.address_space[0], ctx.program_id); - // 10. Build CompressedAccountInfo for CPI - // When PDA decompression is only the first phase of a later token Transfer2 execution, - // the stored input queue index must match that later execution basis, not the original - // packed proof basis. The mixed PDA+token flow uses `output_queue_index` for that. - let input_queue_index = if ctx.cpi_accounts.config().cpi_context { - output_queue_index - } else { - tree_info.queue_pubkey_index - }; - + // 10. Build CompressedAccountInfo for CPI. + // Input nullifiers must keep their original queue basis. The later system-program path + // groups nullifiers by queue index, so rewriting mixed PDA+token inputs onto a shared + // output queue drops whole tree/queue pairs from insertion. let input = InAccountInfo { data_hash: input_data_hash, lamports: 0, merkle_context: PackedMerkleContext { merkle_tree_pubkey_index: tree_info.merkle_tree_pubkey_index, - queue_pubkey_index: input_queue_index, + queue_pubkey_index: tree_info.queue_pubkey_index, leaf_index: tree_info.leaf_index, prove_by_index: tree_info.prove_by_index, }, diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs index 747c75ce32..585a828f03 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/basic_test.rs @@ -472,13 +472,12 @@ async fn test_create_pdas_and_mint_auto() { .await .expect("create_load_instructions should succeed"); - println!("all_instructions.len() = {:?}", all_instructions); - - // Expected: 1 PDA+Token ix + 2 ATA ixs (1 create_ata + 1 decompress) + 1 mint ix = 4 + // Expected: 1 mint load, 1 grouped PDA/token load, and 2 ATA instructions + // (create ATA + Transfer2 decompression) = 4 total. assert_eq!( all_instructions.len(), - 6, - "Should have 6 instructions: 1 PDA, 1 Token, 2 create_ata, 1 decompress_ata, 1 mint" + 4, + "Should have 4 instructions: 1 mint, 1 grouped PDA/token load, 1 create_ata, 1 ATA Transfer2" ); // Capture rent sponsor balance before decompression From 588127c462396a66ace5d00dba338da73a948de1 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Mon, 16 Mar 2026 16:17:20 +0000 Subject: [PATCH 7/9] cleanup: harden prover startup polling --- prover/client/src/prover.rs | 99 +++++++++++++++++++++++++------------ 1 file changed, 67 insertions(+), 32 deletions(-) diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index e0f9d4060d..49b9f5aceb 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -1,7 +1,7 @@ use std::{ io::{Read, Write}, net::{TcpStream, ToSocketAddrs}, - process::Command, + process::{Child, Command}, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; @@ -39,15 +39,15 @@ pub(crate) fn build_http_client() -> Result } fn health_check_once(timeout: Duration) -> bool { - if prover_listener_present() { - return true; - } - let endpoint = SERVER_ADDRESS .strip_prefix("http://") .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) .unwrap_or(SERVER_ADDRESS); - let addr = match endpoint.to_socket_addrs().ok().and_then(|mut addrs| addrs.next()) { + let addr = match endpoint + .to_socket_addrs() + .ok() + .and_then(|mut addrs| addrs.next()) + { Some(addr) => addr, None => return false, }; @@ -64,8 +64,10 @@ fn health_check_once(timeout: Duration) -> bool { let _ = stream.set_write_timeout(Some(timeout)); let host = endpoint.split(':').next().unwrap_or("127.0.0.1"); - let request = - format!("GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", HEALTH_CHECK, host); + let request = format!( + "GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", + HEALTH_CHECK, host + ); if let Err(error) = stream.write_all(request.as_bytes()) { tracing::debug!(?error, "failed to write prover health request"); return health_check_once_with_curl(timeout); @@ -84,25 +86,6 @@ fn health_check_once(timeout: Duration) -> bool { && (has_http_ok_status(&response[..bytes_read]) || health_check_once_with_curl(timeout)) } -fn prover_listener_present() -> bool { - let endpoint = SERVER_ADDRESS - .strip_prefix("http://") - .or_else(|| SERVER_ADDRESS.strip_prefix("https://")) - .unwrap_or(SERVER_ADDRESS); - let port = endpoint.rsplit(':').next().unwrap_or("3001"); - - match Command::new("lsof") - .args(["-nP", &format!("-iTCP:{port}"), "-sTCP:LISTEN"]) - .output() - { - Ok(output) => output.status.success() && !output.stdout.is_empty(), - Err(error) => { - tracing::debug!(?error, "failed to execute lsof prover listener check"); - false - } - } -} - fn health_check_once_with_curl(timeout: Duration) -> bool { let timeout_secs = timeout.as_secs().max(1).to_string(); let url = format!("{}{}", SERVER_ADDRESS, HEALTH_CHECK); @@ -121,6 +104,46 @@ fn health_check_once_with_curl(timeout: Duration) -> bool { } } +async fn wait_for_prover_health( + retries: usize, + timeout: Duration, + child: &mut Child, +) -> Result<(), String> { + for attempt in 0..retries { + if health_check_once(timeout) { + return Ok(()); + } + + match child.try_wait() { + Ok(Some(status)) => { + return Err(format!( + "prover process exited before health check succeeded with status {status}" + )); + } + Ok(None) => {} + Err(error) => { + return Err(format!("failed to poll prover process status: {error}")); + } + } + + if attempt + 1 < retries { + sleep(timeout).await; + } + } + + Err(format!( + "prover health check failed after {} attempts", + retries + )) +} + +fn monitor_prover_child(mut child: Child) { + std::thread::spawn(move || match child.wait() { + Ok(status) => tracing::debug!(?status, "prover launcher exited"), + Err(error) => tracing::warn!(?error, "failed to wait on prover launcher"), + }); +} + pub async fn spawn_prover() { if let Some(_project_root) = get_project_root() { let prover_path = { @@ -150,15 +173,27 @@ pub async fn spawn_prover() { } let spawn_result = async { - Command::new(&prover_path) + let mut child = Command::new(&prover_path) .arg("start-prover") .spawn() .unwrap_or_else(|error| panic!("Failed to start prover process: {error}")); - if health_check(STARTUP_HEALTH_CHECK_RETRIES, 1).await { - info!("Prover started successfully"); - } else { - panic!("Failed to start prover, health check failed."); + match wait_for_prover_health( + STARTUP_HEALTH_CHECK_RETRIES, + Duration::from_secs(1), + &mut child, + ) + .await + { + Ok(()) => { + monitor_prover_child(child); + info!("Prover started successfully"); + } + Err(error) => { + let _ = child.kill(); + let _ = child.wait(); + panic!("Failed to start prover: {error}"); + } } } .await; From badacfa34ce66ec2560a7d6a97d0dc348af17938 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Mon, 16 Mar 2026 16:47:21 +0000 Subject: [PATCH 8/9] format --- forester/src/processor/v2/helpers.rs | 18 +++-- forester/src/processor/v2/proof_worker.rs | 4 +- program-tests/utils/src/e2e_test_env.rs | 3 +- .../utils/src/mock_batched_forester.rs | 4 +- prover/client/tests/batch_address_append.rs | 67 +++++++++---------- .../program-test/src/indexer/test_indexer.rs | 3 +- .../tests/integration_tests.rs | 2 +- 7 files changed, 54 insertions(+), 47 deletions(-) diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs index a9fa2e290d..e6cd8092ea 100644 --- a/forester/src/processor/v2/helpers.rs +++ b/forester/src/processor/v2/helpers.rs @@ -502,9 +502,15 @@ impl StreamingAddressQueue { for (name, len) in [ ("addresses", data.addresses.len()), ("low_element_values", data.low_element_values.len()), - ("low_element_next_values", data.low_element_next_values.len()), + ( + "low_element_next_values", + data.low_element_next_values.len(), + ), ("low_element_indices", data.low_element_indices.len()), - ("low_element_next_indices", data.low_element_next_indices.len()), + ( + "low_element_next_indices", + data.low_element_next_indices.len(), + ), ] { if len < actual_end { return Err(anyhow!( @@ -545,9 +551,11 @@ impl StreamingAddressQueue { low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), low_element_indices: data.low_element_indices[start..actual_end].to_vec(), low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), - low_element_proofs: data.reconstruct_proofs::(start..actual_end).map_err( - |error| anyhow!("incomplete batch data: failed to reconstruct proofs: {error}"), - )?, + low_element_proofs: data + .reconstruct_proofs::(start..actual_end) + .map_err(|error| { + anyhow!("incomplete batch data: failed to reconstruct proofs: {error}") + })?, addresses, leaves_hashchain, })) diff --git a/forester/src/processor/v2/proof_worker.rs b/forester/src/processor/v2/proof_worker.rs index ded9fcedc8..603fa3f19b 100644 --- a/forester/src/processor/v2/proof_worker.rs +++ b/forester/src/processor/v2/proof_worker.rs @@ -164,7 +164,9 @@ impl ProofClients { } } -pub fn spawn_proof_workers(config: &ProverConfig) -> crate::Result> { +pub fn spawn_proof_workers( + config: &ProverConfig, +) -> crate::Result> { let (job_tx, job_rx) = async_channel::bounded::(256); let clients = Arc::new(ProofClients::new(config)?); tokio::spawn(async move { run_proof_pipeline(job_rx, clients).await }); diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index 6c9fdb5d5e..edb94fdf48 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -764,8 +764,7 @@ where // // local_leaves_hash_chain is only used for a test assertion. // let local_nullifier_hash_chain = create_hash_chain_from_array(&addresses); // assert_eq!(leaves_hash_chain, local_nullifier_hash_chain); - let start_index = - address_queue.tree_next_insertion_index as usize; + let start_index = address_queue.tree_next_insertion_index as usize; assert!( start_index >= 2, "start index should be greater than 2 else tree is not inited" diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 8ea51b7169..f3ad76cdbe 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -209,9 +209,7 @@ impl MockBatchedForester { &[], )?; let proof_client = ProofClient::local()?; - let proof_result = proof_client - .generate_batch_update_proof(inputs) - .await?; + let proof_result = proof_client.generate_batch_update_proof(inputs).await?; let new_root = self.merkle_tree.root(); let proof = CompressedProof { a: proof_result.0.proof.a, diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 8a02c363ff..7b8ceaa5f9 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -38,40 +38,39 @@ async fn prove_batch_address_append() { IndexedMerkleTree::::new(DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize, 0) .unwrap(); - let collect_non_inclusion_data = - |tree: &IndexedMerkleTree, values: &[BigUint]| { - let mut low_element_values = Vec::with_capacity(values.len()); - let mut low_element_indices = Vec::with_capacity(values.len()); - let mut low_element_next_indices = Vec::with_capacity(values.len()); - let mut low_element_next_values = Vec::with_capacity(values.len()); - let mut low_element_proofs: Vec< - [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize], - > = Vec::with_capacity(values.len()); - - for new_element_value in values { - let non_inclusion_proof = tree.get_non_inclusion_proof(new_element_value).unwrap(); - - low_element_values.push(non_inclusion_proof.leaf_lower_range_value); - low_element_indices.push(non_inclusion_proof.leaf_index); - low_element_next_indices.push(non_inclusion_proof.next_index); - low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); - low_element_proofs.push( - non_inclusion_proof - .merkle_proof - .as_slice() - .try_into() - .unwrap(), - ); - } - - ( - low_element_values, - low_element_indices, - low_element_next_indices, - low_element_next_values, - low_element_proofs, - ) - }; + let collect_non_inclusion_data = |tree: &IndexedMerkleTree, + values: &[BigUint]| { + let mut low_element_values = Vec::with_capacity(values.len()); + let mut low_element_indices = Vec::with_capacity(values.len()); + let mut low_element_next_indices = Vec::with_capacity(values.len()); + let mut low_element_next_values = Vec::with_capacity(values.len()); + let mut low_element_proofs: Vec<[[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize]> = + Vec::with_capacity(values.len()); + + for new_element_value in values { + let non_inclusion_proof = tree.get_non_inclusion_proof(new_element_value).unwrap(); + + low_element_values.push(non_inclusion_proof.leaf_lower_range_value); + low_element_indices.push(non_inclusion_proof.leaf_index); + low_element_next_indices.push(non_inclusion_proof.next_index); + low_element_next_values.push(non_inclusion_proof.leaf_higher_range_value); + low_element_proofs.push( + non_inclusion_proof + .merkle_proof + .as_slice() + .try_into() + .unwrap(), + ); + } + + ( + low_element_values, + low_element_indices, + low_element_next_indices, + low_element_next_values, + low_element_proofs, + ) + }; let initial_start_index = relayer_merkle_tree.merkle_tree.rightmost_index; let initial_root = relayer_merkle_tree.root(); diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index 6e61c3bd98..584a684a59 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -2073,9 +2073,10 @@ impl TestIndexer { #[cfg(all(test, feature = "v2"))] mod tests { - use super::*; use light_compressed_account::compressed_account::CompressedAccount; + use super::*; + fn queued_account( owner: [u8; 32], merkle_tree: Pubkey, diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs index 9b40b900e5..2c3e82972a 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs @@ -3863,7 +3863,7 @@ async fn test_d9_edge_many_literals() { #[tokio::test] async fn test_d9_edge_mixed() { use csdk_anchor_full_derived_test::d9_seeds::{ - edge_cases::{AB, SEED_123, _UNDERSCORE_CONST}, + edge_cases::{_UNDERSCORE_CONST, AB, SEED_123}, D9EdgeMixedParams, }; From c65e87b6198fdb9b1b5c6cf9860507b9c4e5e790 Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Mon, 16 Mar 2026 17:11:55 +0000 Subject: [PATCH 9/9] format --- .../csdk-anchor-full-derived-test/tests/integration_tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs index 2c3e82972a..9b40b900e5 100644 --- a/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs +++ b/sdk-tests/csdk-anchor-full-derived-test/tests/integration_tests.rs @@ -3863,7 +3863,7 @@ async fn test_d9_edge_many_literals() { #[tokio::test] async fn test_d9_edge_mixed() { use csdk_anchor_full_derived_test::d9_seeds::{ - edge_cases::{_UNDERSCORE_CONST, AB, SEED_123}, + edge_cases::{AB, SEED_123, _UNDERSCORE_CONST}, D9EdgeMixedParams, };