diff --git a/Cargo.toml b/Cargo.toml index a0a0f09f..621ddf00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,6 +149,6 @@ ark-serialize = "0.5" ark-std = { version = "0.5", features = ["std"] } spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ "arkworks-algebra", -] } -spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish" } -whir = { git = "https://github.com/WizardOfMenlo/whir/", features = ["tracing"], rev = "3d627d31cec7d73a470a31a913229dd3128ee0cf" } +], rev = "ecb4f08373ed930175585c856517efdb1851fb47" } +spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish", rev = "ecb4f08373ed930175585c856517efdb1851fb47" } +whir = { git = "https://github.com/WizardOfMenlo/whir/", features = ["tracing"], rev = "c2bafc36a878500a3e19c12d3239d488ff7b5d61" } \ No newline at end of file diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index e08ac26f..b7509738 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -1,9 +1,8 @@ use { - provekit_common::{utils::next_power_of_two, FieldElement, WhirConfig, WhirR1CSScheme, R1CS}, - whir::parameters::{ + provekit_common::{utils::next_power_of_two, WhirConfig, WhirR1CSScheme, R1CS}, std::sync::Arc, whir::{ntt::RSDefault, parameters::{ default_max_pow, DeduplicationStrategy, FoldingFactor, MerkleProofStrategy, MultivariateParameters, ProtocolParameters, SoundnessType, - }, + }} }; // Minimum log2 of the WHIR evaluation domain (lower bound for m). @@ -63,6 +62,8 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { deduplication_strategy: DeduplicationStrategy::Disabled, merkle_proof_strategy: MerkleProofStrategy::Uncompressed, }; - WhirConfig::new(mv_params, whir_params) + let reed_solomon = Arc::new(RSDefault); + let basefield_reed_solomon = reed_solomon.clone(); + WhirConfig::new(reed_solomon, basefield_reed_solomon, mv_params, whir_params) } } diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 68927073..8f2b9c4e 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -51,7 +51,6 @@ func (circuit *Circuit) Define(api frontend.API) error { } rootHash, batchingRandomness, initialOODQueries, initialOODAnswers, err := parseBatchedCommitment(arthur, circuit.WHIRParamsWitness) - if err != nil { return err } @@ -68,7 +67,6 @@ func (circuit *Circuit) Define(api frontend.API) error { } whirFoldingRandomness, err := RunZKWhir(api, arthur, uapi, sc, circuit.WitnessMerkle, circuit.WitnessFirstRound, circuit.WHIRParamsWitness, [][]frontend.Variable{circuit.WitnessClaimedEvaluations, circuit.WitnessBlindingEvaluations}, circuit.WitnessLinearStatementEvaluations, batchingRandomness, initialOODQueries, initialOODAnswers, rootHash) - if err != nil { return err } @@ -88,8 +86,8 @@ func (circuit *Circuit) Define(api frontend.API) error { func verifyCircuit( deferred []Fp256, cfg Config, hints Hints, pk *groth16.ProvingKey, vk *groth16.VerifyingKey, claimedEvaluations ClaimedEvaluations, internedR1CS R1CS, interner Interner, buildOps common.BuildOps, ) error { - transcriptT := make([]uints.U8, cfg.TranscriptLen) - contTranscript := make([]uints.U8, cfg.TranscriptLen) + transcriptT := make([]uints.U8, len(cfg.Transcript)) + contTranscript := make([]uints.U8, len(cfg.Transcript)) for i := range cfg.Transcript { transcriptT[i] = uints.NewU8(cfg.Transcript[i]) diff --git a/recursive-verifier/app/circuit/common.go b/recursive-verifier/app/circuit/common.go index 52acea04..c2ea01b5 100644 --- a/recursive-verifier/app/circuit/common.go +++ b/recursive-verifier/app/circuit/common.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "log" + "math/bits" "github.com/consensys/gnark/backend/groth16" gnarkNimue "github.com/reilabs/gnark-nimue" @@ -14,6 +15,248 @@ import ( "reilabs/whir-verifier-circuit/app/common" ) +type IndexPair struct { + depth uint64 + index uint64 +} + +// func convertMultiIndexMTProofsToFullMultiPath( +// merklePaths []MultiIndexMerkleTreeProof[Digest], +// stirAnswers [][][]Fp256, +// ) []FullMultiPathWithCapping[Digest] { +// fullMerklePaths := make([]FullMultiPathWithCapping[Digest], 0, len(merklePaths)) +// for mIndex, mp := range merklePaths { +// depth := mp.Depth +// proofIter := 0 + +// currentMPAnswers := stirAnswers[mIndex] +// if len(currentMPAnswers) != len(mp.Indices) { +// panic(fmt.Sprintf("mismatched stirAnswers (%d) and indices (%d)", len(currentMPAnswers), len(mp.Indices))) +// } + +// leafHashes := make(map[uint64]Digest) +// for i := range mp.Indices { +// leafHashes[mp.Indices[i]] = HashLeafData(currentMPAnswers[i]) +// } + +// uniqueIndices := make(map[uint64]bool) +// for _, idx := range mp.Indices { +// uniqueIndices[idx] = true +// } + +// indices := make([]uint64, 0, len(uniqueIndices)) +// for idx := range uniqueIndices { +// indices = append(indices, idx) +// } +// sort.Slice(indices, func(i, j int) bool { +// return indices[i] < indices[j] +// }) + +// treeElements := make(map[IndexPair]Digest, len(indices)) + +// cappedDepth := bits.Len(uint(len(mp.Indices))) - 1 + +// if cappedDepth >= int(mp.Depth) { +// // We need to have at least one level, in order to produce the hash and consume the siblings +// cappedDepth = int(mp.Depth) - 1 +// } +// capContainer := make([]Digest, 2< 0; d-- { +// nextIndices := make([]uint64, 0, len(indices)) +// capIndices := make([]uint64, 0, 1<= len(mp.Proof) { +// panic("insufficient siblings") +// } +// sib := mp.Proof[proofIter] +// capIndices = append(capIndices, idx) +// capIndices = append(capIndices, idx+1) +// treeElements[IndexPair{depth: d, index: idx + 1}] = sib +// treeElements[IndexPair{depth: d - 1, index: idx / 2}] = HashTwoDigests(node, sib) +// proofIter++ +// nextIndices = append(nextIndices, idx/2) +// i++ +// } +// } else { +// // right child +// if proofIter >= len(mp.Proof) { +// panic("insufficient siblings") +// } +// sib := mp.Proof[proofIter] +// capIndices = append(capIndices, idx-1) +// capIndices = append(capIndices, idx) +// treeElements[IndexPair{depth: d, index: idx - 1}] = sib +// treeElements[IndexPair{depth: d - 1, index: idx / 2}] = HashTwoDigests(sib, node) + +// proofIter++ +// nextIndices = append(nextIndices, idx/2) +// i++ +// } +// } +// if d <= (uint64(cappedDepth)) { +// for j := range capIndices { +// offset := 1 << d +// capContainer[int(offset)+int(capIndices[j])] = treeElements[IndexPair{depth: d, index: uint64(capIndices[j])}] +// } +// } +// indices = nextIndices +// } + +// capContainer[1] = treeElements[IndexPair{depth: 0, index: 0}] + +// var paths []Path[Digest] +// for _, origIdx := range mp.Indices { +// leafSibling, authPath, err := ExtractAuthPath(treeElements, origIdx, depth) +// if err != nil { +// panic(fmt.Sprintf("failed to extract auth path for index %d: %v", origIdx, err)) +// } + +// paths = append(paths, Path[Digest]{ +// LeafIndex: origIdx, +// LeafSiblingHash: leafSibling, +// AuthPath: authPath[:len(authPath)-cappedDepth], +// }) +// } + +// fullMerklePaths = append(fullMerklePaths, FullMultiPathWithCapping[Digest]{Proofs: paths, CapContainer: capContainer}) +// } +// return fullMerklePaths +// } + +func convertFullMultiPathToFullMultiPathWithCapping( + paths []FullMultiPath[Digest], + stirAnswers [][][]Fp256, +) ([]FullMultiPathWithCapping[Digest], error) { + fullMerklePaths := make([]FullMultiPathWithCapping[Digest], 0, len(paths)) + for mIndex, path := range paths { + currentMPAnswers := stirAnswers[mIndex] + + if len(currentMPAnswers) != len(path.Proofs) { + return nil, fmt.Errorf("mismatched stirAnswers (%d) and indices (%d)", len(currentMPAnswers), len(path.Proofs)) + } + depth := len(path.Proofs[0].AuthPath) + 1 + + cappedDepth := 0 + if len(path.Proofs) > 1 { + cappedDepth = bits.Len(uint(len(path.Proofs))) - 1 + } + if cappedDepth >= depth { + cappedDepth = depth - 1 + } + if cappedDepth < 0 { + cappedDepth = 0 + } + tree := make(map[IndexPair]Digest, 2<= 0; idx-- { + sibling := proof.AuthPath[idx] + siblingIdx := currentIdx ^ 1 + tree[IndexPair{depth: uint64(currentDepth), index: siblingIdx}] = sibling + + if currentIdx%2 == 0 { + currentHash = HashTwoDigests(currentHash, sibling) + } else { + currentHash = HashTwoDigests(sibling, currentHash) + } + + currentIdx /= 2 + currentDepth-- + + tree[IndexPair{depth: uint64(currentDepth), index: currentIdx}] = currentHash + } +} + func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, vk *groth16.VerifyingKey, buildOps common.BuildOps) error { io := gnarkNimue.IOPattern{} err := io.Parse([]byte(config.IOPattern)) @@ -24,7 +267,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v var pointer uint64 var truncated []byte - var merklePaths []FullMultiPath[KeccakDigest] + var merklePaths []FullMultiPath[Digest] var stirAnswers [][][]Fp256 var deferred []Fp256 var claimedEvaluations ClaimedEvaluations @@ -45,7 +288,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v switch string(op.Label) { case "merkle_proof": - var path FullMultiPath[KeccakDigest] + var path FullMultiPath[Digest] _, err = arkSerialize.CanonicalDeserializeWithMode( bytes.NewReader(config.Transcript[start:end]), &path, @@ -121,14 +364,20 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v return fmt.Errorf("failed to deserialize interner: %w", err) } - var hidingSpartanData = consumeWhirData(config.WHIRConfigHidingSpartan, &merklePaths, &stirAnswers) + // convertMultiIndexMTProofsToFullMultiPath(merklePaths, stirAnswers, &fullMerklePaths) + fullMerklePaths, err := convertFullMultiPathToFullMultiPathWithCapping(merklePaths, stirAnswers) + if err != nil { + return fmt.Errorf("failed to convert full multi path to full multi path with capping: %w", err) + } - var witnessData = consumeWhirData(config.WHIRConfigWitness, &merklePaths, &stirAnswers) + var hidingSpartanData = consumeWhirData(config.WHIRConfigHidingSpartan, &fullMerklePaths, &stirAnswers) + var witnessData = consumeWhirData(config.WHIRConfigWitness, &fullMerklePaths, &stirAnswers) hints := Hints{ witnessHints: witnessData, spartanHidingHint: hidingSpartanData, } + err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, r1cs, interner, buildOps) if err != nil { return fmt.Errorf("verification failed: %w", err) @@ -180,3 +429,99 @@ func GetR1csFromUrl(r1csUrl string) ([]byte, error) { log.Printf("Successfully downloaded") return r1csFile, nil } + +func ExtractAuthPath( + treeElements map[IndexPair]Digest, + leafIndex uint64, + depth uint64, +) (leafSiblingHash Digest, authPath []Digest, err error) { + leafSiblingIdx := leafIndex ^ 1 + leafSibling, ok := treeElements[IndexPair{depth: depth, index: leafSiblingIdx}] + if !ok { + return Digest{}, nil, fmt.Errorf("missing leaf sibling at depth=%d, index=%d", depth, leafSiblingIdx) + } + + authPath = make([]Digest, 0, depth-1) + currentIdx := leafIndex + + for d := depth - 1; d >= 1; d-- { + parentIdx := currentIdx / 2 + siblingIdx := parentIdx ^ 1 + + sibling, ok := treeElements[IndexPair{depth: d, index: siblingIdx}] + if !ok { + return Digest{}, nil, fmt.Errorf("missing sibling at depth=%d, index=%d (parent=%d)", d, siblingIdx, parentIdx) + } + + authPath = append(authPath, sibling) + currentIdx = parentIdx + } + + return leafSibling, authPath, nil +} + +func VerifyAuthPath( + leafHash Digest, + leafSiblingHash Digest, + authPath []Digest, + leafIndex uint64, + depth uint64, + expectedRoot Digest, +) error { + var currentHash Digest + if leafIndex%2 == 0 { + currentHash = HashTwoDigests(leafHash, leafSiblingHash) + } else { + currentHash = HashTwoDigests(leafSiblingHash, leafHash) + } + + currentIdx := leafIndex + for level := 0; level < len(authPath); level++ { + parentIdx := currentIdx / 2 + sibling := authPath[level] + if parentIdx%2 == 0 { + currentHash = HashTwoDigests(currentHash, sibling) + } else { + currentHash = HashTwoDigests(sibling, currentHash) + } + currentIdx = parentIdx + } + if currentHash != expectedRoot { + return fmt.Errorf("root mismatch: got %x, expected %x", currentHash, expectedRoot) + } + + return nil +} + +func TestExtractAndVerifyAuthPaths( + treeElements map[IndexPair]Digest, + leafHashes map[uint64]Digest, + indices []uint64, + depth uint64, +) error { + root, ok := treeElements[IndexPair{depth: 0, index: 0}] + if !ok { + return fmt.Errorf("root not found in treeElements") + } + + for _, idx := range indices { + leafSibling, authPath, err := ExtractAuthPath(treeElements, idx, depth) + if err != nil { + return fmt.Errorf("failed to extract auth path for index %d: %w", idx, err) + } + + leafHash, ok := leafHashes[idx] + if !ok { + return fmt.Errorf("leaf hash not found for index %d", idx) + } + + err = VerifyAuthPath(leafHash, leafSibling, authPath, idx, depth, root) + if err != nil { + return fmt.Errorf("failed to verify auth path for index %d: %w", idx, err) + } + + fmt.Printf("✓ Index %d: verified successfully (auth path length: %d)\n", idx, len(authPath)) + } + + return nil +} diff --git a/recursive-verifier/app/circuit/matrix_evaluation.go b/recursive-verifier/app/circuit/matrix_evaluation.go index adaa8466..3dedfa23 100644 --- a/recursive-verifier/app/circuit/matrix_evaluation.go +++ b/recursive-verifier/app/circuit/matrix_evaluation.go @@ -46,7 +46,7 @@ func evaluateR1CSMatrixExtension(api frontend.API, circuit *Circuit, rowRand []f rowEval := calculateEQOverBooleanHypercube(api, rowRand) colEval := calculateEQOverBooleanHypercube(api, colRand) - for i := range len(circuit.MatrixA) { + for i := range circuit.MatrixA { ansA = api.Add(ansA, api.Mul(circuit.MatrixA[i].value, api.Mul(rowEval[circuit.MatrixA[i].row], colEval[circuit.MatrixA[i].column]))) } for i := range circuit.MatrixB { diff --git a/recursive-verifier/app/circuit/mt.go b/recursive-verifier/app/circuit/mt.go index 4930c399..2c342f31 100644 --- a/recursive-verifier/app/circuit/mt.go +++ b/recursive-verifier/app/circuit/mt.go @@ -15,33 +15,41 @@ func newMerkle( var totalLeaves = make([][][]frontend.Variable, len(hint.merklePaths)) var totalLeafSiblingHashes = make([][]frontend.Variable, len(hint.merklePaths)) var totalLeafIndexes = make([][]uints.U64, len(hint.merklePaths)) + var totalCapContainer = make([][]frontend.Variable, len(hint.merklePaths)) for i, merkle_path := range hint.merklePaths { var numOfLeavesProved = len(merkle_path.Proofs) var treeHeight = len(merkle_path.Proofs[0].AuthPath) - totalAuthPath[i] = make([][]frontend.Variable, numOfLeavesProved) + if treeHeight > 0 { + totalAuthPath[i] = make([][]frontend.Variable, numOfLeavesProved) + } totalLeaves[i] = make([][]frontend.Variable, numOfLeavesProved) totalLeafSiblingHashes[i] = make([]frontend.Variable, numOfLeavesProved) + totalCapContainer[i] = make([]frontend.Variable, len(merkle_path.CapContainer)) for j := range numOfLeavesProved { - totalAuthPath[i][j] = make([]frontend.Variable, treeHeight) + if treeHeight > 0 { + totalAuthPath[i][j] = make([]frontend.Variable, treeHeight) + } totalLeaves[i][j] = make([]frontend.Variable, len(hint.stirAnswers[i][j])) } totalLeafIndexes[i] = make([]uints.U64, numOfLeavesProved) - if !isContainer { + for k := range merkle_path.CapContainer { + totalCapContainer[i][k] = typeConverters.LittleEndianUint8ToBigInt(merkle_path.CapContainer[k].Digest[:]) + } for j := range numOfLeavesProved { proof := merkle_path.Proofs[j] for z := range treeHeight { totalAuthPath[i][j][z] = typeConverters. - LittleEndianUint8ToBigInt(proof.AuthPath[treeHeight-1-z].KeccakDigest[:]) + LittleEndianUint8ToBigInt(proof.AuthPath[z].Digest[:]) } totalLeafSiblingHashes[i][j] = typeConverters. - LittleEndianUint8ToBigInt(proof.LeafSiblingHash.KeccakDigest[:]) + LittleEndianUint8ToBigInt(proof.LeafSiblingHash.Digest[:]) totalLeafIndexes[i][j] = uints.NewU64(proof.LeafIndex) for k := range hint.stirAnswers[i][j] { @@ -57,6 +65,7 @@ func newMerkle( LeafIndexes: totalLeafIndexes, LeafSiblingHashes: totalLeafSiblingHashes, AuthPaths: totalAuthPath, + CapContainer: totalCapContainer, } } diff --git a/recursive-verifier/app/circuit/skyscraper2.go b/recursive-verifier/app/circuit/skyscraper2.go new file mode 100644 index 00000000..bcf68edc --- /dev/null +++ b/recursive-verifier/app/circuit/skyscraper2.go @@ -0,0 +1,314 @@ +package circuit + +import ( + "math/big" +) + +// BN254 scalar field modulus +var bn254Modulus = new(big.Int) + +// SIGMA_INV constant for Skyscraper +var sigmaInv = new(big.Int) + +// Round constants for Skyscraper +var roundConstants [18]*big.Int + +func init() { + // BN254 modulus: 21888242871839275222246405745257275088548364400416034343698204186575808495617 + bn254Modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + + // SIGMA_INV: 9915499612839321149637521777990102151350674507940716049588462388200839649614 + sigmaInv.SetString("9915499612839321149637521777990102151350674507940716049588462388200839649614", 10) + + // Initialize round constants from the Rust implementation + rcHex := [][4]uint64{ + {0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000}, + {0x903c4324270bd744, 0x873125f708a7d269, 0x081dd27906c83855, 0x276b1823ea6d7667}, + {0x7ac8edbb4b378d71, 0xe29d79f3d99e2cb7, 0x751417914c1a5a18, 0x0cf02bd758a484a6}, + {0xfa7adc6769e5bc36, 0x1c3f8e297cca387d, 0x0eb7730d63481db0, 0x25b0e03f18ede544}, + {0x57847e652f03cfb7, 0x33440b9668873404, 0x955a32e849af80bc, 0x002882fcbe14ae70}, + {0x979231396257d4d7, 0x29989c3e1b37d3c1, 0x12ef02b47f1277ba, 0x039ad8571e2b7a9c}, + {0xb5b48465abbb7887, 0xa72a6bc5e6ba2d2b, 0x4cd48043712f7b29, 0x1142d5410fc1fc1a}, + {0x7ab2c156059075d3, 0x17cb3594047999b2, 0x44f2c93598f289f7, 0x1d78439f69bc0bec}, + {0x05d7a965138b8edb, 0x36ef35a3d55c48b1, 0x8ddfb8a1ac6f1628, 0x258588a508f4ff82}, + {0x1596fb9afccb49e9, 0x9a7367d69a09a95b, 0x9bc43f6984e4c157, 0x13087879d2f514fe}, + {0x295ccd233b4109fa, 0xe1d72f89ed868012, 0x2e9e1eea4bc88a8e, 0x17dadee898c45232}, + {0x9a8590b4aa1f486f, 0xb75834b430e9130e, 0xb8e90b1034d5de31, 0x295c6d1546e7f4a6}, + {0x850adcb74c6eb892, 0x07699ef305b92fc3, 0x4ef96a2ba1720f2d, 0x1288ca0e1d3ed446}, + {0x01960f9349d1b5ee, 0x8ccad30769371c69, 0xe5c81e8991c98662, 0x17563b4d1ae023f3}, + {0x6ba01e9476b32917, 0xa1cb0a3add977bc9, 0x86815a945815f030, 0x2869043be91a1eea}, + {0x81776c885511d976, 0x7475d34f47f414e7, 0x5d090056095d96cf, 0x14941f0aff59e79a}, + {0xbc40b4fd8fc8c034, 0xbb7142c3cce4fd48, 0x318356758a39005a, 0x1ce337a190f4379f}, + {0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000}, + } + + for i := range rcHex { + roundConstants[i] = limbsToBigInt(rcHex[i]) + } +} + +// Convert [4]uint64 limbs to big.Int (little-endian) +func limbsToBigInt(limbs [4]uint64) *big.Int { + result := new(big.Int) + + // limbs[0] + limbs[1]<<64 + limbs[2]<<128 + limbs[3]<<192 + result.SetUint64(limbs[0]) + + temp := new(big.Int) + temp.SetUint64(limbs[1]) + temp.Lsh(temp, 64) + result.Add(result, temp) + + temp.SetUint64(limbs[2]) + temp.Lsh(temp, 128) + result.Add(result, temp) + + temp.SetUint64(limbs[3]) + temp.Lsh(temp, 192) + result.Add(result, temp) + + return result.Mod(result, bn254Modulus) +} + +// Convert big.Int to [4]uint64 limbs +func bigIntToLimbs(val *big.Int) [4]uint64 { + // Ensure value is positive and reduced + v := new(big.Int).Set(val) + v.Mod(v, bn254Modulus) + + limbs := [4]uint64{} + mask := new(big.Int).SetUint64(0xFFFFFFFFFFFFFFFF) + + for i := 0; i < 4; i++ { + temp := new(big.Int).And(v, mask) + limbs[i] = temp.Uint64() + v.Rsh(v, 64) + } + + return limbs +} + +// Field modular addition +func fieldAdd(a, b *big.Int) *big.Int { + result := new(big.Int).Add(a, b) + return result.Mod(result, bn254Modulus) +} + +// Field modular multiplication +func fieldMul(a, b *big.Int) *big.Int { + result := new(big.Int).Mul(a, b) + return result.Mod(result, bn254Modulus) +} + +// Field squaring +func fieldSquare(a *big.Int) *big.Int { + return fieldMul(a, a) +} + +// S-box function for bar operation +func sbox(v byte) byte { + notV := ^v + rotLeft1 := (notV << 1) | (notV >> 7) + rotLeft2 := (v << 2) | (v >> 6) + rotLeft3 := (v << 3) | (v >> 5) + + xor := v ^ (rotLeft1 & rotLeft2 & rotLeft3) + return (xor << 1) | (xor >> 7) // rotate left by 1 +} + +// Bar function: cyclic rotate bytes + s-box +func bar(x *big.Int) *big.Int { + // Convert to 32 bytes (little-endian) + bytes := make([]byte, 32) + xBytes := x.Bytes() + + // x.Bytes() returns big-endian, we need little-endian + for i := 0; i < len(xBytes) && i < 32; i++ { + bytes[i] = xBytes[len(xBytes)-1-i] + } + + // Cyclic rotate by 16 bytes + rotated := make([]byte, 32) + copy(rotated[:16], bytes[16:]) + copy(rotated[16:], bytes[:16]) + + // Apply s-box to each byte + for i := range rotated { + rotated[i] = sbox(rotated[i]) + } + + // Convert back to big.Int (little-endian) + result := new(big.Int) + for i := 31; i >= 0; i-- { + result.Lsh(result, 8) + result.Or(result, new(big.Int).SetUint64(uint64(rotated[i]))) + } + + return result.Mod(result, bn254Modulus) +} + +// SS round (square-sigma round) +func ss(round int, l, r *big.Int) (*big.Int, *big.Int) { + // r += l^2 * sigma_inv + rc[round] + lSquared := fieldSquare(l) + term := fieldMul(lSquared, sigmaInv) + term = fieldAdd(term, roundConstants[round]) + r = fieldAdd(r, term) + + // swap(l, r) + l, r = r, l + + // r += l^2 * sigma_inv + rc[round+1] + lSquared = fieldSquare(l) + term = fieldMul(lSquared, sigmaInv) + term = fieldAdd(term, roundConstants[round+1]) + r = fieldAdd(r, term) + + // swap(l, r) + l, r = r, l + + return l, r +} + +// BB round (bar-bar round) +func bb(round int, l, r *big.Int) (*big.Int, *big.Int) { + // r += bar(l) + rc[round] + barL := bar(l) + r = fieldAdd(r, barL) + r = fieldAdd(r, roundConstants[round]) + + // swap(l, r) + l, r = r, l + + // r += bar(l) + rc[round+1] + barL = bar(l) + r = fieldAdd(r, barL) + r = fieldAdd(r, roundConstants[round+1]) + + // swap(l, r) + l, r = r, l + + return l, r +} + +// Permute function: 9 rounds of alternating ss and bb +func permute(l, r *big.Int) (*big.Int, *big.Int) { + l, r = ss(0, l, r) + l, r = ss(2, l, r) + l, r = ss(4, l, r) + l, r = bb(6, l, r) + l, r = ss(8, l, r) + l, r = bb(10, l, r) + l, r = ss(12, l, r) + l, r = ss(14, l, r) + l, r = ss(16, l, r) + return l, r +} + +// SkyscraperCompress: Main compression function +// Takes two field elements (as [4]uint64 limbs) and returns compressed hash +func SkyscraperCompress(left, right [4]uint64) [4]uint64 { + l := limbsToBigInt(left) + r := limbsToBigInt(right) + + t := new(big.Int).Set(l) + l, _ = permute(l, r) + result := fieldAdd(l, t) + + return bigIntToLimbs(result) +} + +// SkyscraperCompressFp256: Compress two Fp256 field elements +func SkyscraperCompressFp256(left, right Fp256) Digest { + leftLimbs := left.Limbs + rightLimbs := right.Limbs + + resultLimbs := SkyscraperCompress(leftLimbs, rightLimbs) + + // Convert result limbs to Digest (32 bytes, little-endian) + var digest Digest + for i := 0; i < 4; i++ { + limb := resultLimbs[i] + for j := 0; j < 8; j++ { + digest.Digest[i*8+j] = byte(limb & 0xFF) + limb >>= 8 + } + } + + return digest +} + +// HashLeafData: Hash multiple Fp256 elements to create a leaf digest +// This iteratively compresses pairs of elements +func HashLeafData(leafData []Fp256) Digest { + if len(leafData) == 0 { + panic("Cannot hash empty leaf data") + } + + // Start with first element + currentLimbs := leafData[0].Limbs + // Iteratively compress with remaining elements + for i := 1; i < len(leafData); i++ { + currentLimbs = SkyscraperCompress(currentLimbs, leafData[i].Limbs) + } + + var digest Digest + // Convert to Digest + for i := 0; i < 4; i++ { + limb := currentLimbs[i] + for j := 0; j < 8; j++ { + digest.Digest[i*8+j] = byte(limb & 0xFF) + limb >>= 8 + } + } + + return digest +} + +// HashTwoDigests: Hash two digests together +func HashTwoDigests(left, right Digest) Digest { + // Convert digests to [4]uint64 limbs (little-endian) + leftLimbs := [4]uint64{} + rightLimbs := [4]uint64{} + + for i := 0; i < 4; i++ { + var limb uint64 + for j := 0; j < 8; j++ { + limb |= uint64(left.Digest[i*8+j]) << (8 * j) + } + leftLimbs[i] = limb + } + + for i := 0; i < 4; i++ { + var limb uint64 + for j := 0; j < 8; j++ { + limb |= uint64(right.Digest[i*8+j]) << (8 * j) + } + rightLimbs[i] = limb + } + + resultLimbs := SkyscraperCompress(leftLimbs, rightLimbs) + + var digest Digest + for i := 0; i < 4; i++ { + limb := resultLimbs[i] + for j := 0; j < 8; j++ { + digest.Digest[i*8+j] = byte(limb & 0xFF) + limb >>= 8 + } + } + + return digest +} + +func DigestToFieldElement(d Digest) *big.Int { + var result = new(big.Int) + + for i := 31; i >= 0; i-- { + result.Lsh(result, 8) + result.Add(result, big.NewInt(int64(d.Digest[i]))) + } + + result.Mod(result, bn254Modulus) + return result +} diff --git a/recursive-verifier/app/circuit/types.go b/recursive-verifier/app/circuit/types.go index 67bc53b4..2ba591d9 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -6,8 +6,8 @@ import ( ) // Common types -type KeccakDigest struct { - KeccakDigest [32]uint8 +type Digest struct { + Digest [32]uint8 } type Fp256 struct { @@ -24,6 +24,17 @@ type FullMultiPath[Digest any] struct { Proofs []Path[Digest] } +type FullMultiPathWithCapping[Digest any] struct { + Proofs []Path[Digest] + CapContainer []Digest +} + +type MultiIndexMerkleTreeProof[Digest any] struct { + Depth uint64 + Indices []uint64 + Proof []Digest +} + // WHIR specific types type WHIRConfig struct { NRounds int `json:"n_rounds"` @@ -68,19 +79,12 @@ type InitialSumcheckData struct { InitialCombinationRandomness []frontend.Variable } -// Merkle specific types -type MerklePaths struct { - Leaves [][][]frontend.Variable - LeafIndexes [][]uints.U64 - LeafSiblingHashes [][][]uints.U8 - AuthPaths [][][][]uints.U8 -} - type Merkle struct { Leaves [][][]frontend.Variable LeafIndexes [][]uints.U64 LeafSiblingHashes [][]frontend.Variable AuthPaths [][][]frontend.Variable + CapContainer [][]frontend.Variable } // Other types @@ -96,7 +100,6 @@ type Config struct { LogANumTerms int `json:"log_a_num_terms"` IOPattern string `json:"io_pattern"` Transcript []byte `json:"transcript"` - TranscriptLen int `json:"transcript_len"` WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` } @@ -107,7 +110,7 @@ type Hints struct { } type Hint struct { - merklePaths []FullMultiPath[KeccakDigest] + merklePaths []FullMultiPathWithCapping[Digest] stirAnswers [][][]Fp256 } diff --git a/recursive-verifier/app/circuit/utilities.go b/recursive-verifier/app/circuit/utilities.go index 65f76d69..305333cd 100644 --- a/recursive-verifier/app/circuit/utilities.go +++ b/recursive-verifier/app/circuit/utilities.go @@ -242,7 +242,7 @@ func consumeFront[T any](slice *[]T) T { return head } -func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPath[KeccakDigest], stir_answers *[][][]Fp256) ZKHint { +func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPathWithCapping[Digest], stir_answers *[][][]Fp256) ZKHint { var zkHint ZKHint if len(*merkle_paths) > 0 && len(*stir_answers) > 0 { @@ -251,7 +251,7 @@ func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPath[Keccak zkHint.firstRoundMerklePaths = FirstRoundHint{ path: Hint{ - merklePaths: []FullMultiPath[KeccakDigest]{firstRoundMerklePath}, + merklePaths: []FullMultiPathWithCapping[Digest]{firstRoundMerklePath}, stirAnswers: [][][]Fp256{firstRoundStirAnswers}, }, expectedStirAnswers: firstRoundStirAnswers, @@ -260,7 +260,7 @@ func consumeWhirData(whirConfig WHIRConfig, merkle_paths *[]FullMultiPath[Keccak expectedRounds := whirConfig.NRounds - var remainingMerklePaths []FullMultiPath[KeccakDigest] + var remainingMerklePaths []FullMultiPathWithCapping[Digest] var remainingStirAnswers [][][]Fp256 for i := 0; i < expectedRounds && len(*merkle_paths) > 0 && len(*stir_answers) > 0; i++ { diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index b43ec3ad..6853240f 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -65,34 +65,14 @@ func RunZKWhir( ) (totalFoldingRandomness []frontend.Variable, err error) { initialOODs := oodAnswers(api, initialOODAnswers, batchingRandomness) - // batchSizeLen := whirParams.BatchSize initialSumcheckData, lastEval, initialSumcheckFoldingRandomness, err := initialSumcheck(api, arthur, batchingRandomness, initialOODQueries, initialOODs, whirParams, linearStatementEvaluations) if err != nil { return } - copyOfFirstLeaves := make([][][]frontend.Variable, len(firstRound.Leaves)) - for i := range len(firstRound.Leaves) { - copyOfFirstLeaves[i] = make([][]frontend.Variable, len(firstRound.Leaves[i])) - for j := range len(firstRound.Leaves[i]) { - copyOfFirstLeaves[i][j] = make([]frontend.Variable, len(firstRound.Leaves[i][j])) - for k := range len(firstRound.Leaves[i][j]) { - copyOfFirstLeaves[i][j][k] = firstRound.Leaves[i][j][k] - } - } - } - - roundAnswers := make([][][]frontend.Variable, len(circuit.Leaves)+1) - foldSize := 1 << whirParams.FoldingFactorArray[0] collapsed := rlcBatchedLeaves(api, firstRound.Leaves[0], foldSize, whirParams.BatchSize, batchingRandomness) - roundAnswers[0] = collapsed - - for i := range len(circuit.Leaves) { - roundAnswers[i+1] = circuit.Leaves[i] - } - computedFold := computeFold(collapsed, initialSumcheckFoldingRandomness, api) mainRoundData := generateEmptyMainRoundData(whirParams) @@ -130,7 +110,7 @@ func RunZKWhir( if err != nil { return } - err = verifyMerkleTreeProofs(api, uapi, sc, firstRound.LeafIndexes[0], firstRound.Leaves[0], firstRound.LeafSiblingHashes[0], firstRound.AuthPaths[0], rootHashes) + err = verifyMerkleTreeProofs(api, uapi, sc, firstRound.LeafIndexes[0], firstRound.Leaves[0], firstRound.LeafSiblingHashes[0], firstRound.AuthPaths[0], firstRound.CapContainer[0], rootHashes) if err != nil { return } @@ -143,7 +123,7 @@ func RunZKWhir( if err != nil { return } - err = verifyMerkleTreeProofs(api, uapi, sc, circuit.LeafIndexes[r-1], roundAnswers[r], circuit.LeafSiblingHashes[r-1], circuit.AuthPaths[r-1], rootHashList[r-1]) + err = verifyMerkleTreeProofs(api, uapi, sc, circuit.LeafIndexes[r-1], circuit.Leaves[r-1], circuit.LeafSiblingHashes[r-1], circuit.AuthPaths[r-1], circuit.CapContainer[r-1], rootHashList[r-1]) if err != nil { return } @@ -217,144 +197,6 @@ func RunZKWhir( return totalFoldingRandomness, nil } -//nolint:unused -func runWhir( - api frontend.API, - arthur gnarkNimue.Arthur, - uapi *uints.BinaryField[uints.U64], - sc *skyscraper.Skyscraper, - circuit Merkle, - whirParams WHIRParams, - linearStatementEvaluations []frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, -) (totalFoldingRandomness []frontend.Variable, err error) { - if err = fillInAndVerifyRootHash(0, api, uapi, sc, circuit, arthur); err != nil { - return - } - - initialOODQueries, initialOODAnswers, tempErr := fillInOODPointsAndAnswers(whirParams.CommittmentOODSamples, arthur) - if tempErr != nil { - err = tempErr - return - } - - initialCombinationRandomness, tempErr := GenerateCombinationRandomness(api, arthur, whirParams.CommittmentOODSamples+len(linearStatementEvaluations)) - if tempErr != nil { - err = tempErr - return - } - - OODAnswersAndStatmentEvaluations := append(initialOODAnswers, linearStatementEvaluations...) - lastEval := utilities.DotProduct(api, initialCombinationRandomness, OODAnswersAndStatmentEvaluations) - - initialSumcheckFoldingRandomness, lastEval, tempErr := runWhirSumcheckRounds(api, lastEval, arthur, whirParams.FoldingFactorArray[0], 3) - if tempErr != nil { - err = tempErr - return - } - - initialData := InitialSumcheckData{ - InitialOODQueries: initialOODQueries, - InitialCombinationRandomness: initialCombinationRandomness, - } - - computedFold := computeFold(circuit.Leaves[0], initialSumcheckFoldingRandomness, api) - - mainRoundData := generateEmptyMainRoundData(whirParams) - - expDomainGenerator := utilities.Exponent(api, uapi, whirParams.StartingDomainBackingDomainGenerator, uints.NewU64(uint64(1< 0 { + for i := ((len(capContainer) / 2) - 1); i > 0; i-- { + supposedHash := sc.CompressV2(capContainer[2*i], capContainer[2*i+1]) + actualHash := api.Select(api.IsZero(capContainer[2*i]), capContainer[i], supposedHash) + api.AssertIsEqual(actualHash, capContainer[i]) + } + } + api.AssertIsEqual(rootHash, capContainer[1]) + + capDepth := bits.Len(uint(len(capContainer)/2)) - 1 + cappedNodesLUT := logderivlookup.New(api) + + for i := (len(capContainer) / 2); i < len(capContainer); i++ { + cappedNodesLUT.Insert(capContainer[i]) + } + numOfLeavesProved := len(leaves) + + trimmedTreeHeight := 0 + if len(authPaths) > 0 { + trimmedTreeHeight = len(authPaths[0]) + } for i := range numOfLeavesProved { - treeHeight := len(authPaths[i]) + 1 + treeHeight := trimmedTreeHeight + 1 + capDepth leafIndexBits := api.ToBinary(uapi.ToValue(leafIndexes[i]), treeHeight) - leafSiblingHash := leafSiblingHashes[i] + rootIndex := api.FromBinary(leafIndexBits[treeHeight-capDepth:]...) + searchRes := cappedNodesLUT.Lookup(rootIndex) + cappedNodeHash := searchRes[0] + leafSiblingHash := leafSiblingHashes[i] claimedLeafHash := sc.CompressV2(leaves[i][0], leaves[i][1]) for x := range len(leaves[i]) - 2 { claimedLeafHash = sc.CompressV2(claimedLeafHash, leaves[i][x+2]) } - dir := leafIndexBits[0] xLeftChild := api.Select(dir, leafSiblingHash, claimedLeafHash) xRightChild := api.Select(dir, claimedLeafHash, leafSiblingHash) currentHash := sc.CompressV2(xLeftChild, xRightChild) - - for level := 1; level < treeHeight; level++ { + for level := 1; level < treeHeight-capDepth; level++ { indexBit := leafIndexBits[level] siblingHash := authPaths[i][level-1] @@ -40,7 +63,7 @@ func verifyMerkleTreeProofs(api frontend.API, uapi *uints.BinaryField[uints.U64] currentHash = sc.CompressV2(left, right) } - api.AssertIsEqual(currentHash, rootHash) + api.AssertIsEqual(currentHash, cappedNodeHash) } return nil } @@ -157,26 +180,6 @@ func computeWPoly( return value } -//nolint:unused -func fillInAndVerifyRootHash( - roundNum int, - api frontend.API, - uapi *uints.BinaryField[uints.U64], - sc *skyscraper.Skyscraper, - circuit Merkle, - arthur gnarkNimue.Arthur, -) error { - rootHash := make([]frontend.Variable, 1) - if err := arthur.FillNextScalars(rootHash); err != nil { - return err - } - err := verifyMerkleTreeProofs(api, uapi, sc, circuit.LeafIndexes[roundNum], circuit.Leaves[roundNum], circuit.LeafSiblingHashes[roundNum], circuit.AuthPaths[roundNum], rootHash[0]) - if err != nil { - return err - } - return nil -} - func computeFold(leaves [][]frontend.Variable, foldingRandomness []frontend.Variable, api frontend.API) []frontend.Variable { computedFold := make([]frontend.Variable, len(leaves)) for j := range leaves { diff --git a/recursive-verifier/app/keccakSponge/keccakSponge.go b/recursive-verifier/app/keccakSponge/keccakSponge.go deleted file mode 100644 index 84b1f910..00000000 --- a/recursive-verifier/app/keccakSponge/keccakSponge.go +++ /dev/null @@ -1,82 +0,0 @@ -package keccakSponge - -import ( - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/uints" - "github.com/consensys/gnark/std/permutation/keccakf" -) - -type Digest struct { - api frontend.API - uapi *uints.BinaryField[uints.U64] - state [25]uints.U64 - absorb_pos int - squeeze_pos int -} - -func NewKeccak(api frontend.API) (*Digest, error) { - uapi, err := uints.New[uints.U64](api) - if err != nil { - return nil, err - } - return &Digest{ - api: api, - uapi: uapi, - state: newState(), - absorb_pos: 0, - squeeze_pos: 136, - }, nil -} - -func NewKeccakWithTag(api frontend.API, tag []frontend.Variable) (*Digest, error) { - d, _ := NewKeccak(api) - for i := 136; i < 136+len(tag); i++ { - d.state[i/8][i%8].Val = tag[i-136] - } - - return d, nil -} - -func (d *Digest) Absorb(in []frontend.Variable) { - u8Arr := make([]uints.U8, len(in)) - for i := range in { - u8Arr[i].Val = in[i] - } - - for _, inputByte := range u8Arr { - if d.absorb_pos == 136 { - d.state = keccakf.Permute(d.uapi, d.state) - d.absorb_pos = 0 - } - d.state[d.absorb_pos/8][d.absorb_pos%8] = inputByte - d.absorb_pos++ - } - - d.squeeze_pos = 136 -} - -func (d *Digest) AbsorbQuadraticPolynomial(in [][]frontend.Variable) { - for i := range in { - d.Absorb(in[i]) - } -} - -func (d *Digest) Squeeze(len int) (result []frontend.Variable) { - for i := 0; i < len; i++ { - if d.squeeze_pos == 136 { - d.squeeze_pos = 0 - d.absorb_pos = 0 - d.state = keccakf.Permute(d.uapi, d.state) - } - result = append(result, d.state[d.squeeze_pos/8][d.squeeze_pos%8].Val) - d.squeeze_pos++ - } - return result -} - -func newState() (state [25]uints.U64) { - for i := range state { - state[i] = uints.NewU64(0) - } - return -} diff --git a/tooling/provekit-gnark/src/gnark_config.rs b/tooling/provekit-gnark/src/gnark_config.rs index 968ca283..7c5354bf 100644 --- a/tooling/provekit-gnark/src/gnark_config.rs +++ b/tooling/provekit-gnark/src/gnark_config.rs @@ -23,8 +23,6 @@ pub struct GnarkConfig { pub io_pattern: String, /// transcript in byte form pub transcript: Vec, - /// length of the transcript - pub transcript_len: usize, } #[derive(Debug, Serialize, Deserialize)] @@ -117,7 +115,6 @@ pub fn gnark_parameters( log_a_num_terms: a_num_terms, io_pattern: String::from_utf8(io.as_bytes().to_vec()).unwrap(), transcript: transcript.to_vec(), - transcript_len: transcript.to_vec().len(), } }