diff --git a/Cargo.toml b/Cargo.toml index ecb5952..8e5c73c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "cargo-rcc", "circuit-examples", "rcc", + "rcc-lib", "rcc-halo2", "rcc-macro", "rcc-mockbuilder", @@ -15,3 +16,8 @@ halo2_proofs = { git = "https://github.com/powdr-org/halo2", branch = "kilic/shu [patch.crates-io] halo2_proofs = { git = "https://github.com/powdr-org/halo2", branch = "kilic/shuffle"} +[profile.dev.package."circuit-examples"] +debug = false +debug-assertions = false +overflow-checks = false + diff --git a/circuit-examples/Cargo.toml b/circuit-examples/Cargo.toml index 3f77709..4b1790c 100644 --- a/circuit-examples/Cargo.toml +++ b/circuit-examples/Cargo.toml @@ -6,6 +6,7 @@ license = "MIT OR Apache-2.0" [dependencies] rcc = { path = "../rcc" } +rcc-lib = { path = "../rcc-lib" } rcc-macro = { path = "../rcc-macro" } rcc-mockbuilder = { path = "../rcc-mockbuilder" } rcc-halo2 = { path = "../rcc-halo2" } diff --git a/circuit-examples/README.md b/circuit-examples/README.md index f52467d..74a7d2c 100644 --- a/circuit-examples/README.md +++ b/circuit-examples/README.md @@ -18,9 +18,12 @@ well as a circuit config file `example/halo2_example_config.toml`. To run mock proof generation with input `999`, run ``` -time cargo run --release --example halo2_example_runtime 999 +time cargo run --example halo2_example_runtime 999 ``` +Note that we recommend compiling and running the witness generator in `dev` +mode, as running it in `release` mode could result a 50x slower compilation time. + ## Witness generator speed test An example circuit is given in [`examples/mock_example.rs`](examples/mock_example.rs). @@ -37,7 +40,7 @@ This will generate the runtime library file for the circuit in set to `999`, run ``` -time cargo run --release --example mock_example_runtime 999 +time cargo run --example mock_example_runtime 999 ``` A Circom circuit of equivalent functionality is given in @@ -49,4 +52,3 @@ commands inside `example_circom`. time circom example.circom --wasm time node example_js/generate_witness.js example_js/example.wasm input.json outputs.wtns ``` - diff --git a/circuit-examples/examples/halo2_example.rs b/circuit-examples/examples/halo2_example.rs index 6d02383..8228d2c 100644 --- a/circuit-examples/examples/halo2_example.rs +++ b/circuit-examples/examples/halo2_example.rs @@ -17,15 +17,9 @@ fn gen(val: W) -> Vec<(W, W)> { fn my_circuit() { let val = input_wire("val"); - // let ab = gen(val); - // let c: Vec = ab.iter().map(|(ai, bi)| mul_seq(*ai, *bi)).collect(); - // let sum = sum(c); + let ab = gen(val); + let c: Vec = ab.iter().map(|(ai, bi)| mul_seq(*ai, *bi)).collect(); + let sum = sum(c); - val.declare_public("val"); - // sum.declare_public("sum"); - - let bits = val.to_bits_be_strict(); - for (i, b) in bits.iter().enumerate() { - b.declare_public(format!("{i}").as_str()); - } + sum.declare_public("sum"); } diff --git a/circuit-examples/examples/sha256_example.rs b/circuit-examples/examples/sha256_example.rs new file mode 100644 index 0000000..f8f16e6 --- /dev/null +++ b/circuit-examples/examples/sha256_example.rs @@ -0,0 +1,26 @@ +// use rcc_halo2::builder::{H2Wire as W, *}; +use rcc_mockbuilder::mock_builder::{MockWire as W, *}; +use rcc::traits::{UInt32, NaiveUInt32, BoolWire}; +use rcc_lib::sha256::*; + +type U32 = NaiveUInt32>; + +fn read_u32() -> U32 { + let v: Vec<_> = (0..32).map(|i| { + let w = input_wire(format!("{}-th bit", i).as_str()); + builder().assert_bool(w) + }).collect(); + U32 { repr: v.try_into().unwrap_or_else(|_| [Boolean::::from_const(0); 32]) } +} + +/// This circuit takes input a 4-byte string, interprets it as a U32 and hashes it via SHA256 +#[main_component] +fn my_circuit() { + let a = read_u32(); + + let hash: [U32; 8] = sha256(vec![a; 1]); + + for (i, u) in hash.iter().enumerate() { + u.to_dense().declare_public(format!("{i}").as_str()); + } +} diff --git a/circuit-examples/examples/sha256_example_runtime.rs b/circuit-examples/examples/sha256_example_runtime.rs new file mode 100644 index 0000000..09ed0c7 --- /dev/null +++ b/circuit-examples/examples/sha256_example_runtime.rs @@ -0,0 +1,55 @@ +#![allow(unused_imports)] +#![allow(unused_parens)] +#![allow(non_upper_case_globals)] +#![allow(unused_variables)] +#![allow(unused_mut)] + +mod sha256_example_runtime_lib; +use sha256_example_runtime_lib::generate_witnesses; + +use ark_bn254::Fr as F; +use rcc_mockbuilder::runtime::{BigUint, PrimeField}; + +/// Generated via ChatGPT +fn string_to_bool_vector(input: &str) -> Vec { + // Check if the input string has exactly 4 bytes + if input.len() != 4 { + panic!("Input string must be exactly 4 bytes long"); + } + + let mut bool_vector = Vec::with_capacity(32); // 4 bytes * 8 bits per byte = 32 bits + + for byte in input.bytes() { + // Convert the byte into a u8 + let byte_value = byte as u8; + + // Iterate through each bit in the byte (from left to right) + for i in (0..8).rev() { + // Check if the i-th bit is set (1) or not (0) + let bit_is_set = (byte_value >> i) & 1 == 1; + + // Push the result into the bool vector + bool_vector.push(bit_is_set); + } + } + + bool_vector +} + +fn main() { + let args: Vec = std::env::args().collect(); + let mut inputs = std::collections::HashMap::::new(); + + for (i, b) in string_to_bool_vector(args[1].as_str()).iter().enumerate() { + inputs.insert(format!("{i}-th bit"), F::from(*b)); + } + + let (witness, public) = generate_witnesses(inputs); + + print!("hash: "); + (0..8).for_each(|i| { + let bu: BigUint = public.get(format!("{i}").as_str()).unwrap().into_bigint().into(); + print!("{}", bu.to_str_radix(16)); + }); + println!(""); +} diff --git a/rcc-halo2/src/builder.rs b/rcc-halo2/src/builder.rs index 7c42b0b..ba920a1 100644 --- a/rcc-halo2/src/builder.rs +++ b/rcc-halo2/src/builder.rs @@ -2,7 +2,7 @@ #![allow(unused_must_use)] pub use rcc::{ - traits::{AlgBuilder, AlgWire, Boolean, ToBits, ToBitsBuilder}, + traits::{AlgBuilder, AlgWire, BoolWire, Boolean, ToBits, ToBitsBuilder}, Builder, WireLike, }; pub use rcc_macro::{component, component_of, main_component}; @@ -205,8 +205,8 @@ impl H2Builder { /// Fill the selector vector until it is of the same length as the witness vector fn fill_selectors(&mut self) { - let n = self.witness.len() - self.selectors.len(); - if n > 0 { + if self.witness.len() > self.selectors.len() { + let n = self.witness.len() - self.selectors.len(); self.selectors.extend((0..n).map(|_| 0)) } } @@ -241,22 +241,6 @@ impl H2Builder { b } - /// Add a new wire to the witness column that is constraint to `v` - pub fn new_constant_wire(&mut self, v: F) -> H2Wire { - let constant_index = if self.constants.contains_key(&v) { - *self.constants.get(&v).unwrap() - } else { - let l = self.constants.len(); - self.constants.insert(v, l); - l - }; - let w = self.new_wire(); - let us = format!("{}", v.into_bigint()); - self.composer.runtime(quote!( #w = F::from(BigInt!(#us)); )); - self.copys[2].offsets.push((w.row, constant_index)); - w - } - /// Compose runtime code that logs the value of a wire pub fn log(&mut self, wire: H2Wire) { self.runtime(quote! { @@ -430,6 +414,24 @@ impl AlgBuilder for H2Builder { type Constant = F; type Bool = Boolean; + /// Add a new wire to the witness column that is constraint to `v` + fn new_constant_wire(&mut self, v: F) -> H2Wire { + let constant_index = if self.constants.contains_key(&v) { + *self.constants.get(&v).unwrap() + } else { + let l = self.constants.len(); + self.constants.insert(v, l); + l + }; + let w = self.new_wire(); + let us = format!("{}", v.into_bigint()); + self.composer.runtime(quote!( #w = F::from(BigInt!(#us)); )); + self.copys[2].offsets.push((w.row, constant_index)); + self.fill_selectors(); + w + } + + #[component_of(self)] /// Add gadget fn add(&mut self, a: H2Wire, b: H2Wire) -> H2Wire { @@ -678,12 +680,31 @@ impl ToBitsBuilder for H2Builder { v } - fn from_bits_be(&mut self, _: Vec) -> Self { - todo!() + fn from_bits_be(&mut self, bits: Vec) -> Self::Wire { + let v = self.new_wire(); + let num_bits = bits.len(); + + let mut carry = F::from(1); + let mut pow2 = vec![]; + (0..num_bits).for_each(|_| { + pow2.push(self.new_constant_wire(carry)); + carry *= F::from(2); + }); + pow2.reverse(); + + for (i, bit) in bits.iter().enumerate() { + let alg_bit = bit.to_alg(); + let c = pow2[i]; + self.runtime(quote! { + #v = #v + #alg_bit * #c; + }); + } + v } - fn from_bits_le(&mut self, _: Vec) -> Self { - todo!() + fn from_bits_le(&mut self, mut bits: Vec) -> Self::Wire { + bits.reverse(); + self.from_bits_be(bits) } } diff --git a/rcc-halo2/src/runtime.rs b/rcc-halo2/src/runtime.rs index eabfa10..e9a4991 100644 --- a/rcc-halo2/src/runtime.rs +++ b/rcc-halo2/src/runtime.rs @@ -1,7 +1,7 @@ use polyexen::plaf::{ColumnWitness, Witness}; pub use ark_bn254::Fr as F; -pub use ark_ff::{BigInt, BigInteger, Field}; +pub use ark_ff::{BigInt, BigInteger, Field, PrimeField}; pub use halo2_proofs::halo2curves::bn256::Fr; pub use num_bigint::BigUint; diff --git a/rcc-lib/Cargo.toml b/rcc-lib/Cargo.toml new file mode 100644 index 0000000..665aed9 --- /dev/null +++ b/rcc-lib/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "rcc-lib" +version = "0.0.1" +edition = "2021" +license = "MIT OR Apache-2.0" + +[dependencies] +rcc = { path = "../rcc" } diff --git a/rcc-lib/src/lib.rs b/rcc-lib/src/lib.rs new file mode 100644 index 0000000..fd0110a --- /dev/null +++ b/rcc-lib/src/lib.rs @@ -0,0 +1,2 @@ +pub mod sha256; +pub use sha256::*; diff --git a/rcc-lib/src/sha256.rs b/rcc-lib/src/sha256.rs new file mode 100644 index 0000000..a4da1c1 --- /dev/null +++ b/rcc-lib/src/sha256.rs @@ -0,0 +1,146 @@ +use rcc::traits::UInt32; +use rcc::{component_of, WithGlobalBuilder, Builder}; + +const H: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19 +]; + +const K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 +]; + +#[component_of(T::global_builder())] +fn ch(x: T, y: T, z: T) -> T { + (x & y) ^ (!x & z) +} + +#[component_of(T::global_builder())] +fn maj(x: T, y: T, z: T) -> T { + (x & y) ^ (x & z) ^ (y & z) +} + +#[component_of(T::global_builder())] +fn bsigma0(x: T) -> T { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) +} + +#[component_of(T::global_builder())] +fn bsigma1(x: T) -> T { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) +} + +#[component_of(T::global_builder())] +fn ssigma0(x: T) -> T { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) +} + +#[component_of(T::global_builder())] +fn ssigma1(x: T) -> T { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) +} + +#[component_of(T::global_builder())] +pub fn sha256_compression(input: [T; 16], hin: [T; 8]) -> [T; 8] { + let mut w: Vec = Vec::with_capacity(64); + + let mut a = hin[0]; + let mut b = hin[1]; + let mut c = hin[2]; + let mut d = hin[3]; + let mut e = hin[4]; + let mut f = hin[5]; + let mut g = hin[6]; + let mut h = hin[7]; + + let k_vec: Vec = K.iter().map(|u| T::from_const(*u)).collect(); + let kk: [T; 64] = k_vec.try_into().unwrap_or_else(|_| unimplemented!()); + + for t in 0..64 { + let __for_loop = T::global_builder().new_context("for_loop".into()); + if t < 16 { + w.push(input[t]) + } else { + let __a = T::global_builder().new_context("a".into()); + w.push(ssigma1(w[t-2]) + w[t-7] + ssigma0(w[t-15]) + w[t-16]) + }; + + let mut t1 = w[t] + kk[t]; + + let __b = T::global_builder().new_context("b".into()); + t1 = t1 + h + bsigma1(e) + ch(e,f,g); + let t2 = bsigma0(a) + maj(a,b,c); + h = g; + g = f; + f = e; + e = d + t1; + d = c; + c = b; + b = a; + a = t1 + t2; + } + + [ + hin[0] + a, + hin[1] + b, + hin[2] + c, + hin[3] + d, + hin[4] + e, + hin[5] + f, + hin[6] + g, + hin[7] + h, + ] +} + +#[component_of(T::global_builder())] +pub fn sha256(input: Vec) -> [T; 8] { + let n_bits = input.len() * 32; + + let num_zeros = 512 - ((n_bits + 64) % 512) - 1; + + let length_hi = ((n_bits as u64) >> 32) as u32; + let length_lo = n_bits as u32; + + // println!("{length_hi}, {length_lo}"); println!("num_zeros: {num_zeros}"); + + let mut padded = input.clone(); + + if num_zeros >= 31 { + padded.push(UInt32::from_const(1u32 << 31)); + (0..num_zeros / 32).for_each(|_| { + padded.push(UInt32::from_const(0)); + }); + } else { + todo!() + } + + padded.push(UInt32::from_const(length_hi)); + padded.push(UInt32::from_const(length_lo)); + + // println!("length: {}", padded.len()); + + let blocks = padded.chunks(16); + + let mut state = [ + UInt32::from_const(H[0]), + UInt32::from_const(H[1]), + UInt32::from_const(H[2]), + UInt32::from_const(H[3]), + UInt32::from_const(H[4]), + UInt32::from_const(H[5]), + UInt32::from_const(H[6]), + UInt32::from_const(H[7]), + ]; + + for block in blocks { + state = sha256_compression(block.try_into().unwrap(), state) + } + + state +} diff --git a/rcc-macro/src/lib.rs b/rcc-macro/src/lib.rs index ee1c7f1..293ba2d 100644 --- a/rcc-macro/src/lib.rs +++ b/rcc-macro/src/lib.rs @@ -1,7 +1,7 @@ extern crate proc_macro; use proc_macro::TokenStream; use quote::{format_ident, quote}; -use syn::{parse, ItemFn}; +use syn::{parse, ItemFn, Expr}; fn prepend_code_to_function(code: proc_macro2::TokenStream, f: TokenStream) -> TokenStream { let f = parse::(f.clone()).unwrap(); @@ -31,10 +31,11 @@ pub fn component_of(builder_var: TokenStream, item: TokenStream) -> TokenStream let f = parse::(item.clone()).unwrap(); let name = format!("{}", f.sig.ident); let marker = format_ident!("__context_marker"); - let builder_var = format_ident!("{}", format!("{}", builder_var)); + + let builder_expr = parse::(builder_var).unwrap(); let code = quote! { - let #marker = #builder_var.new_context(#name.into()); + let #marker = #builder_expr.new_context(#name.into()); }; prepend_code_to_function(code, item) diff --git a/rcc-mockbuilder/Cargo.toml b/rcc-mockbuilder/Cargo.toml index 4b05829..299cc15 100644 --- a/rcc-mockbuilder/Cargo.toml +++ b/rcc-mockbuilder/Cargo.toml @@ -9,6 +9,7 @@ quote = "1.0" proc-macro2 = "1.0" ark-ff = "0.4.0" ark-bn254 = "0.4.0" +num-bigint = "^0.4" indexmap = "1.9.2" rcc = { path = "../rcc" } rcc-macro = { path = "../rcc-macro" } diff --git a/rcc-mockbuilder/src/lib.rs b/rcc-mockbuilder/src/lib.rs index 498762a..34b7eab 100644 --- a/rcc-mockbuilder/src/lib.rs +++ b/rcc-mockbuilder/src/lib.rs @@ -1,3 +1,4 @@ #![allow(unused_must_use)] +pub mod runtime; pub mod mock_builder; diff --git a/rcc-mockbuilder/src/mock_builder.rs b/rcc-mockbuilder/src/mock_builder.rs index 0b676ae..ad5fccb 100644 --- a/rcc-mockbuilder/src/mock_builder.rs +++ b/rcc-mockbuilder/src/mock_builder.rs @@ -1,15 +1,14 @@ use indexmap::IndexMap; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use rcc::{ - impl_alg_op, - runtime_composer::{Composer, RuntimeComposer, RuntimeWire}, -}; +use rcc::{runtime_composer::{Composer, RuntimeComposer, RuntimeWire}, traits::BoolWire}; pub use rcc::{ + impl_alg_op, + impl_to_bits, impl_global_builder, - traits::{AlgBuilder, AlgWire, Boolean}, Builder, WireLike, + traits::{ToBits, AlgBuilder, AlgWire, Boolean, ToBitsBuilder}, }; pub use rcc_macro::{component, component_of, main_component}; @@ -99,18 +98,6 @@ impl MockBuilder { } } - /// Allocated a constant wire - pub fn new_constant_wire(&mut self, v: F) -> MockWire { - let key = format!("{}", v.into_bigint()); - if self.constants.contains_key(&key) { - *self.constants.get(&key).unwrap() - } else { - let w = self.new_wire(); - self.constants.insert(key, w); - w - } - } - /// Compose runtime code that read an commandline argument into a wire pub fn arg_read(&mut self, wire: MockWire, index: usize) { self.runtime(quote! { @@ -128,11 +115,7 @@ impl MockBuilder { /// Returns a String encoding a closure that computes all the witnesses pub fn compose_rust_witness_gen(&mut self) -> String { let prelude = quote! { - use ark_ff::{BigInt, Field, PrimeField}; - use ark_bn254::Fr as F; - // runtime composer expects WireVal to be defined - type WireVal = F; - type AllWires = Vec; + use rcc_mockbuilder::runtime::*; }; let n = self.runtime_composer.wires.len(); @@ -178,6 +161,18 @@ impl AlgBuilder for MockBuilder { type Constant = F; type Bool = Boolean; + /// Allocated a constant wire + fn new_constant_wire(&mut self, v: F) -> MockWire { + let key = format!("{}", v.into_bigint()); + if self.constants.contains_key(&key) { + *self.constants.get(&key).unwrap() + } else { + let w = self.new_wire(); + self.constants.insert(key, w); + w + } + } + #[component_of(self)] /// Mock add gadget fn add(&mut self, a: MockWire, b: MockWire) -> MockWire { @@ -305,3 +300,96 @@ impl AlgBuilder for MockBuilder { } impl_global_builder!(MockBuilder, MockWire); + +impl ToBitsBuilder for MockBuilder { + const NUM_BITS: usize = 254; + + #[component_of(self)] + fn to_bits_be(&mut self, w: MockWire, num_bits: usize) -> Vec { + assert!(num_bits <= Self::NUM_BITS); + + let alg_bits = self.new_wires(num_bits); + + // Runtime code to compute be bits + self.runtime(quote! { + let u: BigUint = #w.into(); + let base2_bits = u.to_radix_be(2); + let mut bits = vec![]; + if base2_bits.len() <= #num_bits { + bits.extend((0..(#num_bits - base2_bits.len())).map(|_| F::from(0))); + bits.extend(base2_bits.iter().map(|&i| F::from(i))); + } else { + panic!("Error: number has more bits than expected.") + } + }); + + let bits = alg_bits + .iter() + .enumerate() + .map(|(i, &b)| { + self.runtime(quote! { + #b = bits[#i].into(); + }); + self.assert_bool(b) + }) + .collect(); + + let mut carry = F::from(1); + let mut pow2 = vec![]; + (0..num_bits).for_each(|_| { + pow2.push(self.new_constant_wire(carry)); + carry *= F::from(2); + }); + pow2.reverse(); + + w == self.inner_product(pow2, alg_bits); + + bits + } + + fn to_bits_be_strict(&mut self, w: MockWire) -> Vec { + // TODO: additionally constrain that `bits` is less than `p` + self.to_bits_be(w, Self::NUM_BITS) + } + + fn to_bits_le(&mut self, w: MockWire, num_bits: usize) -> Vec { + let mut v = self.to_bits_be(w, num_bits); + v.reverse(); + v + } + + fn to_bits_le_strict(&mut self, w: MockWire) -> Vec { + let mut v = self.to_bits_be_strict(w); + v.reverse(); + v + } + + fn from_bits_be(&mut self, bits: Vec) -> Self::Wire { + let v = self.new_wire(); + let num_bits = bits.len(); + + let mut carry = F::from(1); + let mut pow2 = vec![]; + (0..num_bits).for_each(|_| { + pow2.push(self.new_constant_wire(carry)); + carry *= F::from(2); + }); + pow2.reverse(); + + for (i, bit) in bits.iter().enumerate() { + let alg_bit = bit.to_alg(); + let c = pow2[i]; + self.runtime(quote! { + #v = #v + #alg_bit * #c; + }); + } + v + } + + fn from_bits_le(&mut self, mut bits: Vec) -> Self::Wire { + bits.reverse(); + self.from_bits_be(bits) + } +} + +impl_to_bits!(MockBuilder, MockWire); diff --git a/rcc-mockbuilder/src/runtime.rs b/rcc-mockbuilder/src/runtime.rs new file mode 100644 index 0000000..c801fd4 --- /dev/null +++ b/rcc-mockbuilder/src/runtime.rs @@ -0,0 +1,7 @@ +pub use ark_ff::{BigInt, BigInteger, Field, PrimeField}; +pub use ark_bn254::Fr as F; +pub use num_bigint::BigUint; + +// runtime composer expects WireVal to be defined +pub type WireVal = F; +pub type AllWires = Vec; diff --git a/rcc/src/impl_global_builder.rs b/rcc/src/impl_global_builder.rs index 89ef8ce..2e5aa8b 100644 --- a/rcc/src/impl_global_builder.rs +++ b/rcc/src/impl_global_builder.rs @@ -56,5 +56,13 @@ macro_rules! impl_global_builder { ) -> Vec { builder().smart_map(iter, f) } + + impl rcc::WithGlobalBuilder for $wire { + type Builder = $builder; + fn global_builder() -> &'static mut $builder { + builder() + } + } }; } + diff --git a/rcc/src/lib.rs b/rcc/src/lib.rs index 8a625e2..10e02bc 100644 --- a/rcc/src/lib.rs +++ b/rcc/src/lib.rs @@ -1,5 +1,8 @@ #![allow(unused_must_use)] +/// Re-exports +pub use rcc_macro::{component, component_of, main_component}; + use proc_macro2::TokenStream; use runtime_composer::Composer; @@ -7,6 +10,7 @@ pub mod impl_global_builder; pub mod runtime_composer; pub mod traits; + /// Any data structures or types over wires in a circuit should implement this trait pub trait WireLike: Sized + Copy + Clone { type Builder: Builder; @@ -14,6 +18,14 @@ pub trait WireLike: Sized + Copy + Clone { fn declare_public(self, _name: &str); } +/// This is a sub-trait that can be inherited to enable accessing a global builder +/// RCC provides a macro that implements this automaticaly +pub trait WithGlobalBuilder { + type Builder: 'static + Builder; + + fn global_builder() -> &'static mut ::Builder; +} + /// Circuit builder trait pub trait Builder { type Wire: Sized + Copy + Clone; diff --git a/rcc/src/traits/README.md b/rcc/src/traits/README.md index ae3362a..0289dee 100644 --- a/rcc/src/traits/README.md +++ b/rcc/src/traits/README.md @@ -11,12 +11,13 @@ RCC provides two categories of traits: wire traits and composer traits. ## Design Philosophy -1. Functions of wire traits **should never** take `&mut Composer` as input. We want +1. Functions of wire traits **should never** take `&mut Builder` as input. We want to keep the interfaces as clean as possible and rely on interior mutability - of wires (all wires encode their composer). + of wires (all wires encode their builder). 2. We support standard Rust operators for wire traits when they make sense semantically. 3. When the above is not possible, we define functions with names that provide - clear context on their semantics. e.g. `num_to_bits_be` instead of `num_to_bits`. + clear context on their semantics. e.g. `num_to_bits_be` (big endian) instead + of `num_to_bits`. 4. Expected behavior of trait functions must be clearly documented. diff --git a/rcc/src/traits/alg_bool.rs b/rcc/src/traits/alg_bool.rs index 4446aa0..29eee09 100644 --- a/rcc/src/traits/alg_bool.rs +++ b/rcc/src/traits/alg_bool.rs @@ -47,6 +47,9 @@ pub trait AlgWire: Neg + WireLike { + type Constant: From + From + From; + + fn from_const(_: Self::Constant) -> Self; fn inv_or_panic(self); fn inv_or_any(self); } @@ -55,6 +58,8 @@ pub trait AlgBuilder: Builder { type Constant: From + From + From + From; type Bool: BoolWire; + fn new_constant_wire(&mut self, a: Self::Constant) -> Self::Wire; + fn add(&mut self, a: Self::Wire, b: Self::Wire) -> Self::Wire; fn add_const(&mut self, a: Self::Wire, b: Self::Constant) -> Self::Wire; fn sub(&mut self, a: Self::Wire, b: Self::Wire) -> Self::Wire; @@ -210,7 +215,15 @@ macro_rules! impl_alg_op { } } + impl AlgWire for $wire { + type Constant = $constant_type; + + fn from_const(a: Self::Constant) -> Self { + use rcc::WithGlobalBuilder; + $wire::global_builder().new_constant_wire(a) + } + fn inv_or_panic(self) { self.builder().inv_or_panic(self); } @@ -233,6 +246,7 @@ pub trait BoolWire: { type AlgWire; + fn from_const(b: u32) -> Self; fn to_alg(&self) -> Self::AlgWire; fn then_or_else(&self, then: Self::AlgWire, els: Self::AlgWire) -> Self::AlgWire; } @@ -272,6 +286,15 @@ impl Not for Boolean { impl BoolWire for Boolean { type AlgWire = T; + fn from_const(b: u32) -> Self where ::Constant: From { + let u: u32 = if b > 0 { + 1 + } else { + 0 + }; + Boolean(T::from_const(u.into())) + } + fn to_alg(&self) -> T { self.0 } diff --git a/rcc/src/traits/mod.rs b/rcc/src/traits/mod.rs index 956d320..b2475c5 100644 --- a/rcc/src/traits/mod.rs +++ b/rcc/src/traits/mod.rs @@ -1,5 +1,8 @@ pub mod alg_bool; pub mod to_bits; +pub mod uint; pub use alg_bool::{AlgBuilder, AlgWire, BoolWire, Boolean}; pub use to_bits::{ToBits, ToBitsBuilder}; +pub use uint::*; + diff --git a/rcc/src/traits/to_bits.rs b/rcc/src/traits/to_bits.rs index f09d182..23476fd 100644 --- a/rcc/src/traits/to_bits.rs +++ b/rcc/src/traits/to_bits.rs @@ -3,7 +3,7 @@ use crate::{ Builder, }; -/// A trait indicating that a wire can be decomposed into bits. +/// A trait indicating that the builder supports decomposing a wire into bits pub trait ToBitsBuilder: AlgBuilder { /// Maximum num of bits required to represent the underlying field element, i.e. NUM_BITS is /// the smallest integer such that 2**NUM_BITS - 1 >= p. @@ -20,8 +20,8 @@ pub trait ToBitsBuilder: AlgBuilder { fn to_bits_le(&mut self, w: ::Wire, num_bits: usize) -> Vec; fn to_bits_le_strict(&mut self, w: ::Wire) -> Vec; - fn from_bits_be(&mut self, _: Vec) -> Self; - fn from_bits_le(&mut self, _: Vec) -> Self; + fn from_bits_be(&mut self, _: Vec) -> Self::Wire; + fn from_bits_le(&mut self, _: Vec) -> Self::Wire; } /// A trait indicating that a wire can be decomposed into bits. @@ -43,12 +43,16 @@ pub trait ToBits: AlgWire { fn to_bits_le(self, num_bits: usize) -> Vec; fn to_bits_le_strict(self) -> Vec; + + fn from_bits_be(_: Vec) -> Self; + fn from_bits_le(_: Vec) -> Self; } #[macro_export] /// Automatically implements AlgWire trait for AlgBuilder::Wire macro_rules! impl_to_bits { ($builder:ident, $wire:ident) => { + use rcc::WithGlobalBuilder; impl ToBits for $wire { type Bool = <$builder as AlgBuilder>::Bool; const NUM_BITS: usize = $builder::NUM_BITS; @@ -68,6 +72,14 @@ macro_rules! impl_to_bits { fn to_bits_le_strict(self) -> Vec { self.builder().to_bits_le_strict(self) } + + fn from_bits_be(bits: Vec) -> Self { + $wire::global_builder().from_bits_be(bits) + } + + fn from_bits_le(bits: Vec) -> Self { + $wire::global_builder().from_bits_le(bits) + } } }; } diff --git a/rcc/src/traits/uint.rs b/rcc/src/traits/uint.rs new file mode 100644 index 0000000..33c425a --- /dev/null +++ b/rcc/src/traits/uint.rs @@ -0,0 +1,306 @@ +use crate::{ + traits::{AlgWire, BoolWire, ToBits}, + WithGlobalBuilder +}; + +use std::ops::{Add, BitAnd, BitOr, BitXor, Mul, Not, Sub, Div, Shl, Shr}; + +use super::Boolean; + +/// Trait for an unsigned integer of arbitrary bitlength +pub trait UInt32: + Add + + Add + + Sub + + Sub + + Mul + + Mul + + Div + + Div + + Shl + + Shr + + BitAnd + + BitAnd + + BitOr + + BitOr + + BitXor + + BitXor + + Not + + Sized + + Copy +{ + type DenseRepr; + + fn from_const(_: u32) -> Self; + + fn num_bits(&self) -> usize; + fn rotate_right(&self, c: u32) -> Self { + assert!(c < 32, "Cannot rotate by more than 31 bits"); + (*self >> c) ^ (*self << (32 - c)) + } + fn to_dense(&self) -> Self::DenseRepr; +} + +/// An naive implementaiton of UInt32 using boolean wires +#[derive(Copy, Clone)] +pub struct NaiveUInt32 { + pub repr: [Bool; 32] +} + +impl WithGlobalBuilder for NaiveUInt32> { + type Builder = ::Builder; + + fn global_builder() -> &'static mut ::Builder { + T::global_builder() + } +} + +impl NaiveUInt32> { + pub fn from_vec(repr_vec: Vec>) -> Self { + NaiveUInt32 { repr: repr_vec.try_into().unwrap_or_else(|_| unreachable!()) } + } +} + +impl>> UInt32 for NaiveUInt32> { + type DenseRepr = T; + + fn num_bits(&self) -> usize { + self.repr.len() + } + + fn from_const(c: u32) -> Self { + let repr_vec: Vec<_> = (0..32).map(|i| { + Boolean::::from_const((c >> (31 - i)) & 1) + }).collect(); + NaiveUInt32::from_vec(repr_vec) + } + + fn rotate_right(&self, c: u32) -> Self { + let repr_vec: Vec<_> = (0..32).map(|i| { + if i >= c { + self.repr[(i - c) as usize] + } else { + self.repr[(i - c + 32) as usize] + } + }).collect(); + + NaiveUInt32::from_vec(repr_vec) + } + + fn to_dense(&self) -> T { + T::from_bits_be(self.repr.into()) + } +} + +impl>> Add for NaiveUInt32> { + type Output = Self; + + fn add(self, other: Self) -> Self { + let mut accum_bits = (0..32).map(|i| { + (self.repr[i].to_alg() + other.repr[i].to_alg()) * (1u32 << (31 - i)) + }); + + let mut sum = accum_bits.next().unwrap(); + for a in accum_bits { + sum = sum + a + } + + // Decompose the sum + let mut bits = sum.to_bits_be(33); + bits.remove(0); + NaiveUInt32::from_vec(bits) + } +} + +impl>> Add for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn add(self, other: C) -> Self { + self + Self::from_const(other.into()) + } +} + +impl>> Sub for NaiveUInt32> { + type Output = Self; + + fn sub(self, other: Self) -> Self { + // We use two's complement method here for substraction + self + !other + } +} + +impl>> Sub for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn sub(self, other: C) -> Self { + self + NaiveUInt32::from_const(!other.into()) + } +} + +impl Mul for NaiveUInt32> { + type Output = Self; + + fn mul(self, _other: Self) -> Self { + todo!() + } +} + +impl Mul for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn mul(self, _other: C) -> Self { + todo!() + } +} + +impl Div for NaiveUInt32> { + type Output = Self; + + fn div(self, _other: Self) -> Self { + todo!() + } +} + +impl Div for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn div(self, _other: C) -> Self { + todo!() + } +} + +impl>> BitAnd for NaiveUInt32> { + type Output = Self; + + fn bitand(self, other: Self) -> Self { + let repr_vec: Vec<_> = (0..32).map(|i| { + self.repr[i] & other.repr[i] + }).collect(); + + NaiveUInt32::from_vec(repr_vec) + } +} + +impl>> BitAnd for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn bitand(self, other: C) -> Self { + let other = NaiveUInt32::from_const(other.into()); + self & other + } +} + +impl>> BitOr for NaiveUInt32> { + type Output = Self; + + fn bitor(self, other: Self) -> Self { + let repr_vec: Vec<_> = (0..32).map(|i| { + self.repr[i] | other.repr[i] + }).collect(); + + NaiveUInt32::from_vec(repr_vec) + } +} + +impl>> BitOr for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn bitor(self, other: C) -> Self { + let other = NaiveUInt32::from_const(other.into()); + self | other + } +} + +impl>> BitXor for NaiveUInt32> { + type Output = Self; + + fn bitxor(self, other: Self) -> Self { + let repr_vec: Vec<_> = (0..32).map(|i| { + self.repr[i] ^ other.repr[i] + }).collect(); + + NaiveUInt32::from_vec(repr_vec) + } +} + +impl>> BitXor for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn bitxor(self, other: C) -> Self { + let other = NaiveUInt32::from_const(other.into()); + self ^ other + } +} + +impl Shl for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn shl(self, c: C) -> Self { + let c: u32 = c.into(); + let repr_vec: Vec<_> = (0..32).map(|i| { + if i + c < 32 { + self.repr[(i + c) as usize] + } else { + Boolean::::from_const(0) + } + }).collect(); + + NaiveUInt32::from_vec(repr_vec) + } +} + +impl Shr for NaiveUInt32> +where + C: Into, +{ + type Output = Self; + + fn shr(self, c: C) -> Self { + let c: u32 = c.into(); + let repr_vec: Vec<_> = (0..32).map(|i| { + if i >= c { + self.repr[(i - c) as usize] + } else { + Boolean::::from_const(0) + } + }).collect(); + + NaiveUInt32::from_vec(repr_vec) + } +} + +impl Not for NaiveUInt32> { + type Output = Self; + + fn not(self) -> Self { + let repr_vec: Vec<_> = (0..32).map(|i| { + !self.repr[i] + }).collect(); + + NaiveUInt32::from_vec(repr_vec) + } +}