From 03f4055f5f4fb6432f37c20c05912aea02fef0ff Mon Sep 17 00:00:00 2001 From: Zihan Date: Thu, 17 Jul 2025 18:01:35 -0400 Subject: [PATCH] bv: extend BitVecMutOps to support in place operations Signed-off-by: Zihan --- Cargo.toml | 1 + src/bv/ops.rs | 92 ++++++++++++++++++++++++++++++++++++++ tests/bitvec_arithmetic.rs | 56 ++++++++++++++++++++--- tests/bitvec_bit_ops.rs | 23 ++++++++++ 4 files changed, 165 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f068819..55ef185 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ rust-version = "1.73.0" # optional dependencies for crate interop fraction = { version = "0.15", optional = true } num-bigint = { version = "0.4", optional = true } +paste = "1.0.15" rand = { version = "0.8.5", optional = true, features = ["small_rng"] } serde = { version = "1.0", features = ["derive"], optional = true } diff --git a/src/bv/ops.rs b/src/bv/ops.rs index 6d446d5..40f8f24 100644 --- a/src/bv/ops.rs +++ b/src/bv/ops.rs @@ -465,10 +465,102 @@ pub trait BitVecOps { } } +/// Declares an arithmetic function which takes in two equal size bitvector and modifies +/// a bitvector of same width in place. +macro_rules! declare_in_place_arith_bin_fn { + ($name:ident) => { + paste::paste! { + fn [<$name _in_place>](&mut self, lhs: &impl BitVecOps, rhs: &impl BitVecOps) { + let width = self.width(); + debug_assert_eq!(width, lhs.width()); + debug_assert_eq!(lhs.width(), rhs.width()); + crate::bv::arithmetic::$name(self.words_mut(), lhs.words(), rhs.words(), width); + } + } + }; +} + +/// Declares an arithmetic function which takes in two equal size bitvector and modifies +/// a bitvector of same width in place. +macro_rules! declare_in_place_bit_bin_fn { + ($name:ident) => { + paste::paste! { + fn [<$name _in_place>](&mut self, lhs: &impl BitVecOps, rhs: &impl BitVecOps) { + debug_assert_eq!(self.width(), lhs.width()); + debug_assert_eq!(lhs.width(), rhs.width()); + crate::bv::arithmetic::$name(self.words_mut(), lhs.words(), rhs.words()); + } + } + }; +} + /// Operations over mutable bit-vector values. pub trait BitVecMutOps: BitVecOps { fn words_mut(&mut self) -> &mut [Word]; + declare_in_place_arith_bin_fn!(add); + declare_in_place_arith_bin_fn!(sub); + declare_in_place_arith_bin_fn!(shift_left); + declare_in_place_arith_bin_fn!(shift_right); + declare_in_place_arith_bin_fn!(arithmetic_shift_right); + + fn mul_in_place(&mut self, lhs: &impl BitVecOps, rhs: &impl BitVecOps) { + let width = self.width(); + match (self.words_mut(), lhs.words(), rhs.words()) { + ([dst], [a], [b]) => { + *dst = a.overflowing_mul(*b).0 & mask(width); + } + ([dst_lsb, dst_msb], [a_lsb, a_msb], [b_lsb, b_msb]) => { + [*dst_lsb, *dst_msb] = double_word_to_words( + double_word_from_words(*a_lsb, *a_msb) + .overflowing_mul(double_word_from_words(*b_lsb, *b_msb)) + .0 + & mask_double_word(width), + ); + } + (dst_words_mut, a_words, b_words) => { + crate::bv::arithmetic::mul(dst_words_mut, a_words, b_words, width); + } + } + } + + declare_in_place_bit_bin_fn!(and); + declare_in_place_bit_bin_fn!(or); + declare_in_place_bit_bin_fn!(xor); + + fn concat_in_place(&mut self, lhs: &impl BitVecOps, rhs: &impl BitVecOps) { + let width = self.width(); + debug_assert_eq!(width, lhs.width() + rhs.width()); + crate::bv::arithmetic::concat(self.words_mut(), lhs.words(), rhs.words(), rhs.width()); + } + + fn slice_in_place(&mut self, src: &impl BitVecOps, msb: WidthInt, lsb: WidthInt) { + let width = self.width(); + debug_assert_eq!(width, msb - lsb + 1); + crate::bv::arithmetic::slice(self.words_mut(), src.words(), msb, lsb); + } + + fn not_in_place(&mut self) { + self.words_mut().iter_mut().for_each(|word| *word = !*word); + } + + fn negate_in_place(&mut self) { + let width = self.width(); + crate::bv::arithmetic::negate_in_place(self.words_mut(), width); + } + + fn sign_extend_in_place(&mut self, src: &impl BitVecOps, by: WidthInt) { + let width = self.width(); + debug_assert_eq!(width, src.width() + by); + crate::bv::arithmetic::sign_extend(self.words_mut(), src.words(), src.width(), width); + } + + fn zero_extend_in_place(&mut self, src: &impl BitVecOps, by: WidthInt) { + let width = self.width(); + debug_assert_eq!(width, src.width() + by); + crate::bv::arithmetic::zero_extend(self.words_mut(), src.words()); + } + fn assign<'a>(&mut self, value: impl Into>) { let value = value.into(); debug_assert_eq!(self.width(), value.width()); diff --git a/tests/bitvec_arithmetic.rs b/tests/bitvec_arithmetic.rs index 5d7d952..389e480 100644 --- a/tests/bitvec_arithmetic.rs +++ b/tests/bitvec_arithmetic.rs @@ -13,14 +13,14 @@ use proptest::prelude::*; #[cfg(feature = "bigint")] fn do_test_arith( - a: BigInt, - b: BigInt, + a: &BigInt, + b: &BigInt, width: WidthInt, our: fn(&BitVecValue, &BitVecValue) -> BitVecValue, big: fn(BigInt, BigInt) -> BigInt, ) { - let a_vec = BitVecValue::from_big_int(&a, width); - let b_vec = BitVecValue::from_big_int(&b, width); + let a_vec = BitVecValue::from_big_int(a, width); + let b_vec = BitVecValue::from_big_int(b, width); let res = our(&a_vec, &b_vec); // check result @@ -31,19 +31,61 @@ fn do_test_arith( assert_eq!(expected, res, "{a} {b} {expected_num}"); } +#[cfg(feature = "bigint")] +fn do_test_arith_in_place( + a: &BigInt, + b: &BigInt, + width: WidthInt, + our: fn(&mut BitVecValue, &BitVecValue, &BitVecValue), + big: fn(BigInt, BigInt) -> BigInt, +) { + let a_vec = BitVecValue::from_big_int(a, width); + let b_vec = BitVecValue::from_big_int(b, width); + let mut res = BitVecValue::zero(width); + our(&mut res, &a_vec, &b_vec); + + // check result + let expected_mask = (BigInt::from(1) << width) - 1; + let expected_num: BigInt = big(a.clone(), b.clone()) & expected_mask; + // after masking, only the magnitude counts + let expected = BitVecValue::from_big_uint(expected_num.magnitude(), width); + assert_eq!(expected, res, "{a} {b} {expected_num}"); +} + #[cfg(feature = "bigint")] fn do_test_add(a: BigInt, b: BigInt, width: WidthInt) { - do_test_arith(a, b, width, |a, b| a.add(b), |a, b| a + b) + do_test_arith(&a, &b, width, |a, b| a.add(b), |a, b| a + b); + do_test_arith_in_place( + &a, + &b, + width, + |dst, a, b| dst.add_in_place(a, b), + |a, b| a + b, + ); } #[cfg(feature = "bigint")] fn do_test_sub(a: BigInt, b: BigInt, width: WidthInt) { - do_test_arith(a, b, width, |a, b| a.sub(b), |a, b| a - b) + do_test_arith(&a, &b, width, |a, b| a.sub(b), |a, b| a - b); + do_test_arith_in_place( + &a, + &b, + width, + |dst, a, b| dst.sub_in_place(a, b), + |a, b| a - b, + ); } #[cfg(feature = "bigint")] fn do_test_mul(a: BigInt, b: BigInt, width: WidthInt) { - do_test_arith(a, b, width, |a, b| a.mul(b), |a, b| a * b) + do_test_arith(&a, &b, width, |a, b| a.mul(b), |a, b| a * b); + do_test_arith_in_place( + &a, + &b, + width, + |dst, a, b| dst.mul_in_place(a, b), + |a, b| a * b, + ); } ////////////////////////// diff --git a/tests/bitvec_bit_ops.rs b/tests/bitvec_bit_ops.rs index 92c3898..77c94e2 100644 --- a/tests/bitvec_bit_ops.rs +++ b/tests/bitvec_bit_ops.rs @@ -20,6 +20,9 @@ fn do_test_concat(a: &str, b: &str) { let c_value = a_value.concat(&b_value); let expected = format!("{a}{b}"); assert_eq!(c_value.to_bit_str(), expected); + let mut c_value = BitVecValue::zero(c_value.width()); + c_value.concat_in_place(&a_value, &b_value); + assert_eq!(c_value.to_bit_str(), expected); } fn do_test_slice(src: &str, hi: WidthInt, lo: WidthInt) { @@ -38,6 +41,9 @@ fn do_test_slice(src: &str, hi: WidthInt, lo: WidthInt) { .take(res.width() as usize) .collect(); assert_eq!(res.to_bit_str(), expected); + let mut res = BitVecValue::zero(hi - lo + 1); + res.slice_in_place(&src_value, hi, lo); + assert_eq!(res.to_bit_str(), expected); } fn do_test_shift(src: &str, by: WidthInt, right: bool, signed: bool) { @@ -73,6 +79,17 @@ fn do_test_shift(src: &str, by: WidthInt, right: bool, signed: bool) { } let expected = BitVecValue::from_bit_str(&expected).unwrap(); assert_eq!(res, expected, "{src:?} {by} {res:?} {expected:?}"); + let mut res = BitVecValue::zero(a.width()); + if right { + if signed { + res.arithmetic_shift_right_in_place(&a, &b); + } else { + res.shift_right_in_place(&a, &b); + } + } else { + res.shift_left_in_place(&a, &b) + }; + assert_eq!(res, expected, "{src:?} {by} {res:?} {expected:?}"); } fn do_test_shift_right(src: &str, by: WidthInt) { @@ -100,6 +117,9 @@ fn do_test_zero_ext(src: &str, by: WidthInt) { actual.to_bit_str(), expected.to_bit_str() ); + let mut actual = BitVecValue::zero(value.width() + by); + actual.zero_extend_in_place(&value, by); + assert_eq!(actual, expected); } fn do_test_sign_ext(src: &str, by: WidthInt) { @@ -118,6 +138,9 @@ fn do_test_sign_ext(src: &str, by: WidthInt) { actual.to_bit_str(), expected.to_bit_str() ); + let mut actual = BitVecValue::zero(value.width() + by); + actual.sign_extend_in_place(&value, by); + assert_eq!(actual, expected); } //////////////////////////