diff --git a/bitcoin-script-riscv/src/riscv/instructions.rs b/bitcoin-script-riscv/src/riscv/instructions.rs index a5d109c..12a1de7 100644 --- a/bitcoin-script-riscv/src/riscv/instructions.rs +++ b/bitcoin-script-riscv/src/riscv/instructions.rs @@ -6,6 +6,8 @@ use bitvmx_cpu_definitions::trace::TraceRWStep; use riscv_decode::Instruction; use riscv_decode::Instruction::*; +use crate::riscv::memory_alignment::clear_least_significant_bit; +use crate::riscv::memory_alignment::verify_alignment; use crate::ScriptValidation; use super::decoder::*; @@ -240,6 +242,7 @@ pub fn op_conditional( 0, ); let write_pc = ret[0]; + verify_alignment(stack, write_pc); let micro = stack.number(0); stack.rename(micro, "write_micro"); @@ -341,6 +344,7 @@ pub fn op_jal( let write_pc = add_with_bit_extension(stack, &tables, pc, imm, StackVariable::null()); stack.rename(write_pc, "write_pc"); + verify_alignment(stack, write_pc); let micro = stack.number(0); stack.rename(micro, "write_micro"); @@ -393,7 +397,9 @@ pub fn op_jalr( let write_pc = add_with_bit_extension(stack, &tables, trace_read.read_1_value, imm, bit_extension); + let write_pc = clear_least_significant_bit(stack, write_pc); stack.rename(write_pc, "write_pc"); + verify_alignment(stack, write_pc); let micro = stack.number(0); stack.rename(micro, "write_micro"); diff --git a/bitcoin-script-riscv/src/riscv/memory_alignment.rs b/bitcoin-script-riscv/src/riscv/memory_alignment.rs index 467c04e..3163dba 100644 --- a/bitcoin-script-riscv/src/riscv/memory_alignment.rs +++ b/bitcoin-script-riscv/src/riscv/memory_alignment.rs @@ -6,6 +6,13 @@ use super::{ script_utils::{number_u32_partial, WordTable}, }; +pub fn load_clear_lsb_table(stack: &mut StackTracker) -> StackVariable { + for i in (0..16).rev() { + stack.number(i & !1); + } + stack.join_in_stack(16, None, Some("clear_lsb_table")) +} + pub fn load_modulo_4_table(stack: &mut StackTracker) -> StackVariable { for i in (0..16).rev() { stack.number(i % 4); @@ -47,6 +54,26 @@ pub fn is_aligned( result } +pub fn verify_alignment(stack: &mut StackTracker, mem_address: StackVariable) { + let lower_half_nibble_table = load_lower_half_nibble_table(stack); + is_aligned(stack, mem_address, false, &lower_half_nibble_table); + stack.op_verify(); + stack.drop(lower_half_nibble_table); +} + +pub fn clear_least_significant_bit(stack: &mut StackTracker, mem_address: StackVariable) -> StackVariable { + let parts = stack.explode(mem_address); + let table = load_clear_lsb_table(stack); + + stack.move_var(parts[7]); + stack.get_value_from_table(table, None); + stack.to_altstack(); + stack.drop(table); + + stack.from_altstack(); + stack.join_count(parts[0], 7) +} + //get's the memory address to be read, and returns the aligned memory address and the alignment delta pub fn align_memory( stack: &mut StackTracker, diff --git a/emulator/src/decision/challenge.rs b/emulator/src/decision/challenge.rs index d54440b..588f9cd 100644 --- a/emulator/src/decision/challenge.rs +++ b/emulator/src/decision/challenge.rs @@ -1881,9 +1881,12 @@ mod tests { false, fail_execute, None, + None, + None, true, ForceCondition::No, ForceChallenge::No, + ForceChallenge::No, ); } @@ -1917,9 +1920,34 @@ mod tests { false, fail_execute, None, + None, + None, true, ForceCondition::No, ForceChallenge::No, + ForceChallenge::No, + ); + } + + #[test] + fn test_challenge_non_aligned_jump() { + init_trace(); + + let fail_mem_protection = FailConfiguration::new_fail_memory_protection(); + + test_challenge_aux( + "audit_15", + "audit_15.yaml", + 0, + true, + Some(fail_mem_protection), + None, + None, + None, + true, + ForceCondition::No, + ForceChallenge::No, + ForceChallenge::No, ); } } diff --git a/emulator/src/executor/fetcher.rs b/emulator/src/executor/fetcher.rs index 241decd..36b64f3 100644 --- a/emulator/src/executor/fetcher.rs +++ b/emulator/src/executor/fetcher.rs @@ -315,6 +315,11 @@ pub fn execute_step( } }; + let new_pc = program.pc.get_address(); + if !fail_config.fail_memory_protection && new_pc % 4 != 0 { + return Err(ExecutionResult::UnalignedJump(new_pc)); + } + let trace = TraceRWStep::new( program.step, read_1, @@ -491,7 +496,7 @@ pub fn op_jalr( ) }; - program.pc.jump(wrapping_add_itype(src_value, x)); + program.pc.jump(wrapping_add_itype(src_value, x) & !1); (read_1, TraceRead::default(), write_1, mem_witness) } diff --git a/emulator/src/lib.rs b/emulator/src/lib.rs index f4d6baa..c9530e2 100644 --- a/emulator/src/lib.rs +++ b/emulator/src/lib.rs @@ -60,6 +60,9 @@ pub enum ExecutionResult { #[error("Failed to verify the bitcoin script {0}")] BitcoinScriptVerification(#[from] ScriptValidation), + + #[error("Tried to jump to unaligned address: {0}")] + UnalignedJump(u32), } pub mod constants { diff --git a/emulator/tests/audit.rs b/emulator/tests/audit.rs index ad8a694..3675d9b 100644 --- a/emulator/tests/audit.rs +++ b/emulator/tests/audit.rs @@ -13,23 +13,22 @@ fn audit_tests() { if let Ok(path) = path { let fname = path.file_name(); let fname = fname.to_string_lossy(); - if fname.ends_with("verify.elf") && (fname.contains("09") || fname.contains("12")) || fname.contains("13") { + if fname.ends_with("verify.elf") { let path = path.path(); let path = path.to_string_lossy(); let (result, _) = verify_file(&format!("{}", path), true).unwrap(); match result { - ExecutionResult::Halt(exit_code, _) => { - assert!(exit_code == 0, "Error executing file {}", path); + ExecutionResult::Halt(0, _) => { info!("File {} executed successfully", path); count += 1; } - _ => assert!(false, "Error executing file {}", path), + _ => panic!("Error executing file {}", path), } } } } info!("Total files executed: {}", count); - assert_eq!(count, 3); + assert_eq!(count, 4); } diff --git a/emulator/tests/test_i_type_instructions.rs b/emulator/tests/test_i_type_instructions.rs index 316a3da..d19bdc9 100644 --- a/emulator/tests/test_i_type_instructions.rs +++ b/emulator/tests/test_i_type_instructions.rs @@ -54,7 +54,7 @@ fn test_jalr() { let _ = op_jalr(&x, &mut program); - assert_eq!(program.pc.get_address(), imm + rs1_value); + assert_eq!(program.pc.get_address(), (imm + rs1_value) & !1); assert_eq!(program.registers.get(rd), 4); }