diff --git a/vm/rust/src/jsonrpc.rs b/vm/rust/src/jsonrpc.rs index a76a9d25aa..d29faa8bc7 100644 --- a/vm/rust/src/jsonrpc.rs +++ b/vm/rust/src/jsonrpc.rs @@ -1,28 +1,27 @@ use blockifier; -use blockifier::execution::entry_point::CallType; use blockifier::execution::call_info::OrderedL2ToL1Message; -use cairo_vm::vm::runners::builtin_runner::{ - BITWISE_BUILTIN_NAME, EC_OP_BUILTIN_NAME, HASH_BUILTIN_NAME, - POSEIDON_BUILTIN_NAME, RANGE_CHECK_BUILTIN_NAME, SIGNATURE_BUILTIN_NAME, KECCAK_BUILTIN_NAME, - SEGMENT_ARENA_BUILTIN_NAME, -}; +use blockifier::execution::entry_point::CallType; use blockifier::state::cached_state::TransactionalState; use blockifier::state::errors::StateError; use blockifier::state::state_api::{State, StateReader}; +use cairo_vm::vm::runners::builtin_runner::{ + BITWISE_BUILTIN_NAME, EC_OP_BUILTIN_NAME, HASH_BUILTIN_NAME, KECCAK_BUILTIN_NAME, + POSEIDON_BUILTIN_NAME, RANGE_CHECK_BUILTIN_NAME, SEGMENT_ARENA_BUILTIN_NAME, + SIGNATURE_BUILTIN_NAME, +}; use serde::Serialize; -use starknet_api::core::{ClassHash, ContractAddress, EntryPointSelector, PatriciaKey, EthAddress}; +use starknet_api::core::{ClassHash, ContractAddress, EntryPointSelector, EthAddress, PatriciaKey}; use starknet_api::deprecated_contract_class::EntryPointType; use starknet_api::hash::StarkFelt; use starknet_api::transaction::{Calldata, EventContent, L2ToL1Payload}; use starknet_api::transaction::{DeclareTransaction, Transaction as StarknetApiTransaction}; -use crate::juno_state_reader::JunoStateReader; - #[derive(Serialize, Default)] #[serde(rename_all = "UPPERCASE")] pub enum TransactionType { // dummy type for implementing Default trait - #[default] Unknown, + #[default] + Unknown, Invoke, Declare, #[serde(rename = "DEPLOY_ACCOUNT")] @@ -124,7 +123,7 @@ type BlockifierTxInfo = blockifier::transaction::objects::TransactionExecutionIn pub fn new_transaction_trace( tx: &StarknetApiTransaction, info: BlockifierTxInfo, - state: &mut TransactionalState, + state: &mut TransactionalState, ) -> Result { let mut trace = TransactionTrace::default(); let mut deprecated_declared_class: Option = None; @@ -225,14 +224,38 @@ impl From for ExecutionResources { } else { None }, - range_check_builtin_applications: val.builtin_instance_counter.get(RANGE_CHECK_BUILTIN_NAME).cloned(), - pedersen_builtin_applications: val.builtin_instance_counter.get(HASH_BUILTIN_NAME).cloned(), - poseidon_builtin_applications: val.builtin_instance_counter.get(POSEIDON_BUILTIN_NAME).cloned(), - ec_op_builtin_applications: val.builtin_instance_counter.get(EC_OP_BUILTIN_NAME).cloned(), - ecdsa_builtin_applications: val.builtin_instance_counter.get(SIGNATURE_BUILTIN_NAME).cloned(), - bitwise_builtin_applications: val.builtin_instance_counter.get(BITWISE_BUILTIN_NAME).cloned(), - keccak_builtin_applications: val.builtin_instance_counter.get(KECCAK_BUILTIN_NAME).cloned(), - segment_arena_builtin: val.builtin_instance_counter.get(SEGMENT_ARENA_BUILTIN_NAME).cloned(), + range_check_builtin_applications: val + .builtin_instance_counter + .get(RANGE_CHECK_BUILTIN_NAME) + .cloned(), + pedersen_builtin_applications: val + .builtin_instance_counter + .get(HASH_BUILTIN_NAME) + .cloned(), + poseidon_builtin_applications: val + .builtin_instance_counter + .get(POSEIDON_BUILTIN_NAME) + .cloned(), + ec_op_builtin_applications: val + .builtin_instance_counter + .get(EC_OP_BUILTIN_NAME) + .cloned(), + ecdsa_builtin_applications: val + .builtin_instance_counter + .get(SIGNATURE_BUILTIN_NAME) + .cloned(), + bitwise_builtin_applications: val + .builtin_instance_counter + .get(BITWISE_BUILTIN_NAME) + .cloned(), + keccak_builtin_applications: val + .builtin_instance_counter + .get(KECCAK_BUILTIN_NAME) + .cloned(), + segment_arena_builtin: val + .builtin_instance_counter + .get(SEGMENT_ARENA_BUILTIN_NAME) + .cloned(), } } } @@ -262,7 +285,9 @@ impl FunctionInvocation { } } +use crate::MemState; use blockifier::execution::call_info::CallInfo as BlockifierCallInfo; + impl From for FunctionInvocation { fn from(val: BlockifierCallInfo) -> Self { FunctionInvocation { @@ -327,7 +352,7 @@ impl From for OrderedMessage { pub struct Retdata(pub Vec); fn make_state_diff( - state: &mut TransactionalState, + state: &mut TransactionalState, deprecated_declared_class: Option, ) -> Result { let diff = state.to_state_diff(); diff --git a/vm/rust/src/juno_state_reader.rs b/vm/rust/src/juno_state_reader.rs index 44ed5bada9..6a0c4b0d94 100644 --- a/vm/rust/src/juno_state_reader.rs +++ b/vm/rust/src/juno_state_reader.rs @@ -193,6 +193,9 @@ pub fn contract_class_from_json_str(raw_json: &str) -> Result, serde_json::Error> = @@ -206,9 +210,9 @@ pub extern "C" fn cairoVMExecute( eth_l1_gas_price: felt_to_u128(gas_price_wei_felt), strk_l1_gas_price: felt_to_u128(gas_price_strk_felt), }, - None + None, ); - let mut state = CachedState::new(reader, GlobalContractCache::default()); + let mut state = CachedState::new(mem_state, GlobalContractCache::default()); let charge_fee = skip_charge_fee == 0; let validate = skip_validate == 0; @@ -273,19 +277,18 @@ pub extern "C" fn cairoVMExecute( Err(error) => { let err_string = match &error { ContractConstructorExecutionFailed(e) - | ExecutionError(e) - | ValidateTransactionError(e) => format!("{error} {e}"), - other => other.to_string() + | ExecutionError(e) + | ValidateTransactionError(e) => format!("{error} {e}"), + other => other.to_string(), }; report_error( reader_handle, format!( "failed txn {} reason: {}", - txn_and_query_bit.txn_hash, - err_string, + txn_and_query_bit.txn_hash, err_string, ) .as_str(), - txn_index as i64 + txn_index as i64, ); return; } @@ -293,16 +296,16 @@ pub extern "C" fn cairoVMExecute( if t.is_reverted() && err_on_revert != 0 { report_error( reader_handle, - format!("reverted: {}", t.revert_error.unwrap()) - .as_str(), - txn_index as i64 + format!("reverted: {}", t.revert_error.unwrap()).as_str(), + txn_index as i64, ); return; } // we are estimating fee, override actual fee calculation if !charge_fee { - t.actual_fee = calculate_tx_fee(&t.actual_resources, &block_context, &fee_type).unwrap(); + t.actual_fee = + calculate_tx_fee(&t.actual_resources, &block_context, &fee_type).unwrap(); } let actual_fee = t.actual_fee.0.into(); @@ -316,7 +319,7 @@ pub extern "C" fn cairoVMExecute( trace.err().unwrap() ) .as_str(), - txn_index as i64 + txn_index as i64, ); return; } @@ -410,8 +413,20 @@ fn build_block_context( // https://github.com/starknet-io/starknet-addresses/blob/df19b17d2c83f11c30e65e2373e8a0c65446f17c/bridged_tokens/mainnet.json fee_token_addresses: FeeTokenAddresses { // both addresses are the same for all networks - eth_fee_token_address: ContractAddress::try_from(StarkHash::try_from("0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7").unwrap()).unwrap(), - strk_fee_token_address: ContractAddress::try_from(StarkHash::try_from("0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d").unwrap()).unwrap(), + eth_fee_token_address: ContractAddress::try_from( + StarkHash::try_from( + "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", + ) + .unwrap(), + ) + .unwrap(), + strk_fee_token_address: ContractAddress::try_from( + StarkHash::try_from( + "0x04718f5a0fc34cc1af16a1cdee98ffb20c31f5cd61d6ab07201858f4287c938d", + ) + .unwrap(), + ) + .unwrap(), }, gas_prices, // fixed gas price, so that we can return "consumed gas" to Go side vm_resource_fee_cost: HashMap::from([ @@ -436,7 +451,10 @@ fn build_block_context( (KECCAK_BUILTIN_NAME.to_string(), N_STEPS_FEE_WEIGHT * 2048.0), ]) .into(), - invoke_tx_max_n_steps: max_steps.unwrap_or(MAX_STEPS_PER_TX as u64).try_into().unwrap(), + invoke_tx_max_n_steps: max_steps + .unwrap_or(MAX_STEPS_PER_TX as u64) + .try_into() + .unwrap(), validate_max_n_steps: MAX_VALIDATE_STEPS_PER_TX as u32, max_recursion_depth: 50, } diff --git a/vm/rust/src/mem_state.rs b/vm/rust/src/mem_state.rs new file mode 100644 index 0000000000..27b4a59b3c --- /dev/null +++ b/vm/rust/src/mem_state.rs @@ -0,0 +1,171 @@ +use blockifier::execution::contract_class::ContractClass; +use blockifier::state::cached_state::CommitmentStateDiff; +use blockifier::state::errors::StateError; +use blockifier::state::state_api::{State, StateReader, StateResult}; +use cached::{Cached, SizedCache}; +use once_cell::sync::Lazy; +use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; +use starknet_api::hash::StarkFelt; +use starknet_api::state::StorageKey; +use std::sync::Mutex; + +struct CachedContractClass { + pub definition: ContractClass, + pub cached_on_height: u64, +} + +static CLASS_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(SizedCache::with_size(128))); +static STORAGE_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(SizedCache::with_size(128))); +static CLASS_HASH_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(SizedCache::with_size(128))); +static NONCE_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(SizedCache::with_size(128))); +static COMPILED_CLASS_HASH: Lazy>> = + Lazy::new(|| Mutex::new(SizedCache::with_size(128))); + +pub struct MemState { + height: u64, +} + +impl MemState { + pub fn new(height: u64) -> Self { + Self { height } + } +} + +impl StateReader for MemState { + fn get_storage_at( + &mut self, + contract_address: ContractAddress, + key: StorageKey, + ) -> StateResult { + if let Some(value) = STORAGE_CACHE + .lock() + .unwrap() + .cache_get(&(contract_address, key)) + { + return Ok(value.clone()); + } + return Err(StateError::StateReadError(format!( + "failed to read location {} at address {}", + key.0.key(), + contract_address.0.key() + ))); + } + + fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult { + if let Some(nonce) = NONCE_CACHE.lock().unwrap().cache_get(&contract_address) { + return Ok(nonce.clone()); + } + return Err(StateError::StateReadError(format!( + "failed to read nonce of address {}", + contract_address.0.key() + ))); + } + + fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult { + if let Some(class_hash) = CLASS_HASH_CACHE + .lock() + .unwrap() + .cache_get(&contract_address) + { + return Ok(class_hash.clone()); + } + return Err(StateError::StateReadError(format!( + "failed to read class hash of address {}", + contract_address.0.key() + ))); + } + + fn get_compiled_contract_class( + &mut self, + class_hash: &ClassHash, + ) -> StateResult { + if let Some(cached_class) = CLASS_CACHE.lock().unwrap().cache_get(class_hash) { + if cached_class.cached_on_height < self.height { + return Ok(cached_class.definition.clone()); + } + } + return Err(StateError::UndeclaredClassHash(*class_hash)); + } + + fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult { + if let Some(compiled_class_hash) = + COMPILED_CLASS_HASH.lock().unwrap().cache_get(&class_hash) + { + return Ok(compiled_class_hash.clone()); + } + return Err(StateError::UndeclaredClassHash(class_hash.clone())); + } +} + +impl State for MemState { + fn set_storage_at( + &mut self, + contract_address: ContractAddress, + key: StorageKey, + value: StarkFelt, + ) { + let _ = STORAGE_CACHE + .lock() + .unwrap() + .cache_set((contract_address, key), value); + } + + fn increment_nonce(&mut self, contract_address: ContractAddress) -> StateResult<()> { + let current_nonce = self.get_nonce_at(contract_address)?; + let current_nonce_as_u64 = usize::try_from(current_nonce.0)? as u64; + let next_nonce_val = 1_u64 + current_nonce_as_u64; + let next_nonce = Nonce(StarkFelt::from(next_nonce_val)); + let _ = NONCE_CACHE + .lock() + .unwrap() + .cache_set(contract_address, next_nonce); + Ok(()) + } + + fn set_class_hash_at( + &mut self, + contract_address: ContractAddress, + class_hash: ClassHash, + ) -> StateResult<()> { + let _ = CLASS_HASH_CACHE + .lock() + .unwrap() + .cache_set(contract_address, class_hash); + Ok(()) + } + + fn set_contract_class( + &mut self, + class_hash: &ClassHash, + contract_class: ContractClass, + ) -> StateResult<()> { + let _ = CLASS_CACHE.lock().unwrap().cache_set( + *class_hash, + CachedContractClass { + definition: contract_class, + cached_on_height: self.height, + }, + ); + Ok(()) + } + + fn set_compiled_class_hash( + &mut self, + class_hash: ClassHash, + compiled_class_hash: CompiledClassHash, + ) -> StateResult<()> { + let _ = COMPILED_CLASS_HASH + .lock() + .unwrap() + .cache_set(class_hash, compiled_class_hash); + Ok(()) + } + + fn to_state_diff(&mut self) -> CommitmentStateDiff { + todo!() + } +}