diff --git a/Cargo.lock b/Cargo.lock index 9f57867be..1a55db8c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1061,7 +1061,7 @@ dependencies = [ [[package]] name = "ceno_crypto_primitives" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#b79232a6fc80c799f584380273fd6d3055b36808" +source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#6e90bc85bceefe09003f84ca5cc1afbe2911a2db" dependencies = [ "ceno_syscall", "elliptic-curve", @@ -1168,9 +1168,11 @@ name = "ceno_rt" version = "0.1.0" dependencies = [ "ceno_serde", + "ceno_syscall", "getrandom 0.2.16", "getrandom 0.3.2", "serde", + "tiny-keccak", ] [[package]] @@ -1193,7 +1195,7 @@ dependencies = [ [[package]] name = "ceno_syscall" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#b79232a6fc80c799f584380273fd6d3055b36808" +source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#6e90bc85bceefe09003f84ca5cc1afbe2911a2db" [[package]] name = "ceno_zkvm" diff --git a/ceno_cli/example/src/main.rs b/ceno_cli/example/src/main.rs index fc88e6613..aa6f5f655 100644 --- a/ceno_cli/example/src/main.rs +++ b/ceno_cli/example/src/main.rs @@ -27,5 +27,5 @@ fn main() { panic!(); } - ceno_rt::commit(&cnt_primes); + ceno_rt::commit(&cnt_primes.to_le_bytes()); } diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index 584702bd4..15e736507 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -169,25 +169,9 @@ impl CenoOptions { self.heap_size.next_multiple_of(WORD_SIZE as u32) } - /// Read the public io into ceno stdin - pub fn read_public_io(&self) -> anyhow::Result> { - if let Some(public_io) = &self.public_io { - // if vector contains only one element, write it as a raw `u32` - // otherwise, write the entire vector - // in both cases, convert the resulting `CenoStdin` into a `Vec` - if public_io.len() == 1 { - CenoStdin::default() - .write(&public_io[0]) - .map(|stdin| Into::>::into(&*stdin)) - } else { - CenoStdin::default() - .write(public_io) - .map(|stdin| Into::>::into(&*stdin)) - } - .context("failed to get public_io".to_string()) - } else { - Ok(vec![]) - } + /// Read raw public-io words; digesting happens later in the zkVM pipeline. + pub fn read_public_io(&self) -> Vec { + self.public_io.clone().unwrap_or_default() } /// Read the hints @@ -367,14 +351,14 @@ fn run_elf_inner< options.max_cycle_per_shard, ); - let public_io = options - .read_public_io() - .context("failed to read public io")?; + let public_io_digest_input = options.read_public_io(); + let public_io_digest = public_io_words_to_digest_words(&public_io_digest_input); + tracing::debug!("public io digest words: {:?}", public_io_digest); let public_io_size = options.public_io_size; assert!( - public_io.len() <= public_io_size as usize / WORD_SIZE, + public_io_digest_input.len() <= public_io_size as usize / WORD_SIZE, "require pub io length {} < max public_io_size {}", - public_io.len(), + public_io_digest_input.len(), public_io_size as usize / WORD_SIZE ); @@ -416,7 +400,7 @@ fn run_elf_inner< platform, multi_prover, &hints, - &public_io, + &public_io_digest_input, options.max_steps, checkpoint, options.shard_id.map(|v| v as usize), diff --git a/ceno_cli/src/sdk.rs b/ceno_cli/src/sdk.rs index cbcb601b9..ce31fcd8b 100644 --- a/ceno_cli/src/sdk.rs +++ b/ceno_cli/src/sdk.rs @@ -153,8 +153,16 @@ where shard_id: Option, ) -> Vec> { if let Some(zkvm_prover) = self.zkvm_prover.as_ref() { - let init_full_mem = zkvm_prover.setup_init_mem(&Vec::from(&hints), &Vec::from(&pub_io)); - run_e2e_proof::(zkvm_prover, &init_full_mem, max_steps, false, shard_id) + let public_io_words = Vec::from(&pub_io); + let init_full_mem = zkvm_prover.setup_init_mem(&Vec::from(&hints), &public_io_words); + run_e2e_proof::( + zkvm_prover, + &init_full_mem, + &public_io_words, + max_steps, + false, + shard_id, + ) } else { panic!("ZKVMProver is not initialized") } diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 6b16a3587..8bf8eacb1 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -31,7 +31,7 @@ pub mod disassemble; mod syscalls; pub use syscalls::{ BLS12381_ADD, BLS12381_DECOMPRESS, BLS12381_DOUBLE, BN254_ADD, BN254_DOUBLE, BN254_FP_ADD, - BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, SECP256K1_ADD, + BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, PubIoCommitSpec, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, SyscallSpec, UINT256_MUL, diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 75c7e8f11..16e7eddb0 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -13,7 +13,6 @@ use crate::addr::{Addr, RegIdx}; pub struct Platform { pub rom: Range, pub prog_data: Arc>, - pub public_io: Range, pub stack: Range, pub heap: Range, @@ -34,7 +33,7 @@ impl Display for Platform { write!( f, "Platform {{ rom: {:#x}..{:#x}, prog_data: {:#x}..{:#x}, stack: {:#x}..{:#x}, heap: {:#x}..{:#x}, \ - public_io: {:#x}..{:#x}, hints: {:#x}..{:#x}, unsafe_ecall_nop: {} }}", + hints: {:#x}..{:#x}, unsafe_ecall_nop: {} }}", self.rom.start, self.rom.end, prog_data @@ -49,8 +48,6 @@ impl Display for Platform { self.stack.end, self.heap.start, self.heap.end, - self.public_io.start, - self.public_io.end, self.hints.start, self.hints.end, self.unsafe_ecall_nop @@ -81,21 +78,21 @@ impl Display for Platform { // │ STACK (≈128 MB, grows downward) // │ 0x1800_0000 .. 0x2000_0000 // │ -// ├───────────────────────────── 0x1800_0000 (stack base / pubio end) +// ├───────────────────────────── 0x1800_0000 (stack base) // │ -// │ PUBLIC I/O (128 MB) -// │ 0x1000_0000 .. 0x1800_0000 +// │ STACK (stack-only memory window, includes reserved low-address area +// │ previously used for PUBLIC I/O) +// │ 0x1000_0000 .. 0x2000_0000 // │ -// ├───────────────────────────── 0x1000_0000 (pubio base / rom end) +// ├───────────────────────────── 0x1000_0000 (stack base / rom end) // │ // │ ROM / TEXT / RODATA (128 MB) // │ 0x0800_0000 .. 0x1000_0000 // │ // └───────────────────────────── 0x8000_0000 (rom base) pub static CENO_PLATFORM: Lazy = Lazy::new(|| Platform { - rom: 0x0800_0000..0x1000_0000, // 128 MB - public_io: 0x1000_0000..0x1800_0000, // 128 MB - stack: 0x1800_0000..0x2000_4000, // stack grows downward 128MB, 0x4000 reserved for debug io. + rom: 0x0800_0000..0x1000_0000, // 128 MB + stack: 0x1000_0000..0x2000_4000, // stack grows downward, 0x4000 reserved for debug io. // we make hints start from 0x2800_0000 thus reserve a 128MB gap for debug io // at the end of stack hints: 0x2800_0000..0x3000_0000, // 128 MB @@ -123,10 +120,6 @@ impl Platform { self.stack.contains(&addr) || self.heap.contains(&addr) || self.is_prog_data(addr) } - pub fn is_pub_io(&self, addr: Addr) -> bool { - self.public_io.contains(&addr) - } - pub fn is_hints(&self, addr: Addr) -> bool { self.hints.contains(&addr) } @@ -155,7 +148,7 @@ impl Platform { } pub fn can_write(&self, addr: Addr) -> bool { - self.is_ram(addr) || self.is_pub_io(addr) || self.is_hints(addr) + self.is_ram(addr) || self.is_hints(addr) } // Environment calls. @@ -187,13 +180,7 @@ impl Platform { /// Validate the platform configuration, range shall not overlap. pub fn validate(&self) -> bool { - let mut ranges = [ - &self.rom, - &self.stack, - &self.heap, - &self.public_io, - &self.hints, - ]; + let mut ranges = [&self.rom, &self.stack, &self.heap, &self.hints]; ranges.sort_by_key(|r| r.start); for i in 0..ranges.len() - 1 { if ranges[i].end > ranges[i + 1].start { diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 5d9674fc6..368ea88cf 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -4,6 +4,7 @@ use anyhow::Result; pub mod bn254; pub mod keccak_permute; pub mod phantom; +pub mod pubio_commit; pub mod secp256k1; pub(crate) mod secp256r1; pub mod sha256; @@ -14,9 +15,11 @@ pub mod uint256; pub use ceno_syscall::{ BLS12381_ADD, BLS12381_DECOMPRESS, BLS12381_DOUBLE, BN254_ADD, BN254_DOUBLE, BN254_FP_ADD, BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, PHANTOM_LOG_PC_CYCLE, - SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, - SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, UINT256_MUL, + PUB_IO_COMMIT, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, + SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, + UINT256_MUL, }; +pub use pubio_commit::PubIoCommitSpec; pub trait SyscallSpec { const NAME: &'static str; @@ -49,6 +52,7 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< BN254_FP2_ADD => Ok(bn254::bn254_fp2_add(vm)), BN254_FP2_MUL => Ok(bn254::bn254_fp2_mul(vm)), UINT256_MUL => Ok(uint256::uint256_mul(vm)), + PUB_IO_COMMIT => Ok(pubio_commit::pubio_commit(vm)), // phantom syscall PHANTOM_LOG_PC_CYCLE => Ok(phantom::log_pc_cycle(vm)), diff --git a/ceno_emul/src/syscalls/pubio_commit.rs b/ceno_emul/src/syscalls/pubio_commit.rs new file mode 100644 index 000000000..f21a2109d --- /dev/null +++ b/ceno_emul/src/syscalls/pubio_commit.rs @@ -0,0 +1,33 @@ +use crate::{Change, EmuContext, Platform, Tracer, VMState, WriteOp, utils::MemoryView}; + +use super::{SyscallEffects, SyscallSpec, SyscallWitness}; + +const PUBIO_COMMIT_WORDS: usize = 8; + +pub struct PubIoCommitSpec; +impl SyscallSpec for PubIoCommitSpec { + const NAME: &'static str = "PUB_IO_COMMIT"; + const REG_OPS_COUNT: usize = 1; + const MEM_OPS_COUNT: usize = PUBIO_COMMIT_WORDS; + const CODE: u32 = ceno_syscall::PUB_IO_COMMIT; +} + +/// Trace the PUB_IO_COMMIT syscall by reading 8 digest words from guest memory. +pub fn pubio_commit(vm: &VMState) -> SyscallEffects { + let digest_ptr = vm.peek_register(Platform::reg_arg0()); + + let reg_ops = vec![WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(digest_ptr, digest_ptr), + 0, + )]; + + let digest_view = MemoryView::<_, PUBIO_COMMIT_WORDS>::new(vm, digest_ptr); + let mem_ops = digest_view.mem_ops().to_vec(); + + assert_eq!(mem_ops.len(), PubIoCommitSpec::MEM_OPS_COUNT); + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} diff --git a/ceno_host/src/lib.rs b/ceno_host/src/lib.rs index 719288ece..cfe9a9d1d 100644 --- a/ceno_host/src/lib.rs +++ b/ceno_host/src/lib.rs @@ -115,7 +115,7 @@ pub fn run( platform: Platform, elf: &[u8], hints: &CenoStdin, - public_io: Option<&CenoStdin>, + _public_io: Option<&CenoStdin>, ) -> Vec> { let program = Program::load_elf(elf, u32::MAX).unwrap(); let platform = Platform { @@ -124,9 +124,7 @@ pub fn run( }; let hints: Vec = hints.into(); - let pubio: Vec = public_io.map(|c| c.into()).unwrap_or_default(); let hints_range = platform.hints.clone(); - let pubio_range = platform.public_io.clone(); let mut state = VMState::new(platform, Arc::new(program)); @@ -134,10 +132,6 @@ pub fn run( state.init_memory(addr.into(), value); } - for (addr, value) in zip(pubio_range.iter_addresses(), pubio) { - state.init_memory(addr.into(), value); - } - state .iter_until_halt() .collect::>>() diff --git a/ceno_recursion/src/aggregation/internal.rs b/ceno_recursion/src/aggregation/internal.rs index bf0fdfac2..568f1c17b 100644 --- a/ceno_recursion/src/aggregation/internal.rs +++ b/ceno_recursion/src/aggregation/internal.rs @@ -128,7 +128,7 @@ impl NonLeafVerifierVariables { // let expected_last_shard_id = Usize::uninit(builder); // builder.assign(&expected_last_shard_id, pv.len() - Usize::from(1)); - // let shard_id_fs = builder.get(&shard_raw_pi, SHARD_ID_IDX); + // let shard_id_fs = builder.get(&shard_pi, SHARD_ID_IDX); // let shard_id_f = builder.get(&shard_id_fs, 0); // let shard_id = Usize::Var(builder.cast_felt_to_var(shard_id_f)); // builder.assert_usize_eq(expected_last_shard_id, shard_id); diff --git a/ceno_recursion/src/aggregation/mod.rs b/ceno_recursion/src/aggregation/mod.rs index d17829fe9..8d3706f8d 100644 --- a/ceno_recursion/src/aggregation/mod.rs +++ b/ceno_recursion/src/aggregation/mod.rs @@ -265,7 +265,7 @@ impl CenoAggregationProver { .collect(); let user_public_values: Vec = zkvm_proof_inputs .iter() - .flat_map(|p| p.raw_pi.iter().flat_map(|v| v.clone()).collect::>()) + .flat_map(|p| p.pi.to_vec()) .collect(); let leaf_inputs = chunk_ceno_leaf_proof_inputs(zkvm_proof_inputs); @@ -398,7 +398,7 @@ impl CenoLeafVmVerifierConfig { builder.cycle_tracker_start("Verify Ceno ZKVM Proof"); let zkvm_proof = ceno_leaf_input.proof; - let raw_pi = zkvm_proof.raw_pi.clone(); + let pi = zkvm_proof.pi.clone(); let _calculated_shard_ec_sum = verify_zkvm_proof(&mut builder, zkvm_proof, &self.vk); builder.cycle_tracker_end("Verify Ceno ZKVM Proof"); @@ -409,19 +409,10 @@ impl CenoLeafVmVerifierConfig { builder.assign(&stark_pvs.app_commit[i], F::ZERO); } - let pv = &raw_pi; - let init_pc = { - let arr = builder.get(pv, INIT_PC_IDX); - builder.get(&arr, 0) - }; - let end_pc = { - let arr = builder.get(pv, END_PC_IDX); - builder.get(&arr, 0) - }; - let exit_code = { - let arr = builder.get(pv, EXIT_CODE_IDX); - builder.get(&arr, 0) - }; + let pv = π + let init_pc = builder.get(pv, INIT_PC_IDX); + let end_pc = builder.get(pv, END_PC_IDX); + let exit_code = builder.get(pv, EXIT_CODE_IDX); builder.assign(&stark_pvs.connector.initial_pc, init_pc); builder.assign(&stark_pvs.connector.final_pc, end_pc); builder.assign(&stark_pvs.connector.exit_code, exit_code); diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index 08c76dbe0..78f4c3c99 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -71,9 +71,7 @@ pub fn decompose_prefixed_layer_bits(n: usize) -> (Vec, Vec>) { #[derive(DslVariable, Clone)] pub struct ZKVMProofInputVariable { pub shard_id: Usize, - pub raw_pi: Array>>, - pub raw_pi_num_variables: Array>, - pub pi_evals: Array>, + pub pi: Array>, pub chip_proofs: Array>>, pub max_num_var: Var, pub max_width: Var, @@ -95,9 +93,7 @@ pub struct TowerProofInputVariable { pub(crate) struct ZKVMProofInput { pub shard_id: usize, - pub raw_pi: Vec>, - // Evaluation of raw_pi. - pub pi_evals: Vec, + pub pi: Vec, pub chip_proofs: BTreeMap, pub witin_commit: BasefoldCommitment, pub opening_proof: BasefoldProof, @@ -109,6 +105,8 @@ impl ZKVMProofInput { zkvm_proof: ZKVMProof, vk: &ZKVMVerifyingKey, ) -> Self { + let pi = zkvm_proof.public_values.iter_field::().collect_vec(); + let mut chip_witin_num_vars: HashMap = HashMap::new(); // (chip_id, (num_witin, num_fixed)) let mut chip_indices = zkvm_proof .chip_proofs @@ -136,8 +134,7 @@ impl ZKVMProofInput { ZKVMProofInput { shard_id, - raw_pi: zkvm_proof.raw_pi, - pi_evals: zkvm_proof.pi_evals, + pi, chip_proofs: zkvm_proof .chip_proofs .into_iter() @@ -168,9 +165,7 @@ impl Hintable for ZKVMProofInput { fn read(builder: &mut Builder) -> Self::HintVariable { let shard_id = Usize::Var(usize::read(builder)); - let raw_pi = Vec::>::read(builder); - let raw_pi_num_variables = Vec::::read(builder); - let pi_evals = Vec::::read(builder); + let pi = Vec::::read(builder); builder.cycle_tracker_start("read chip proofs"); let chip_proofs = Vec::::read(builder); builder.cycle_tracker_end("read chip proofs"); @@ -186,9 +181,7 @@ impl Hintable for ZKVMProofInput { ZKVMProofInputVariable { shard_id, - raw_pi, - raw_pi_num_variables, - pi_evals, + pi, chip_proofs, max_num_var, max_width, @@ -201,11 +194,6 @@ impl Hintable for ZKVMProofInput { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - let raw_pi_num_variables: Vec = self - .raw_pi - .iter() - .map(|v| ceil_log2(v.len().next_power_of_two())) - .collect(); let witin_num_vars = self .chip_proofs .iter() @@ -217,21 +205,21 @@ impl Hintable for ZKVMProofInput { .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .map(|proof| proof.wits_in_evals.len().max(1)) + .map(|proof| proof.num_witin.max(1)) .collect::>(); let fixed_num_vars = self .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .filter(|proof| !proof.fixed_in_evals.is_empty()) + .filter(|proof| proof.num_fixed > 0) .map(|proof| proof.num_vars) .collect::>(); let fixed_max_widths = self .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .filter(|proof| !proof.fixed_in_evals.is_empty()) - .map(|proof| proof.fixed_in_evals.len()) + .filter(|proof| proof.num_fixed > 0) + .map(|proof| proof.num_fixed) .collect::>(); let max_num_var = witin_num_vars .iter() @@ -263,9 +251,7 @@ impl Hintable for ZKVMProofInput { let fixed_perm = get_perm(fixed_num_vars); stream.extend(>::write(&self.shard_id)); - stream.extend(self.raw_pi.write()); - stream.extend(raw_pi_num_variables.write()); - stream.extend(self.pi_evals.write()); + stream.extend(self.pi.write()); stream.extend(vec![vec![F::from_canonical_usize(self.chip_proofs.len())]]); for proofs in self.chip_proofs.values() { stream.extend(proofs.write()); @@ -403,9 +389,6 @@ pub struct ZKVMChipProofInput { pub ecc_proof: EccQuarkProofInput, pub num_instances: Vec, - - pub wits_in_evals: Vec, - pub fixed_in_evals: Vec, } impl VecAutoHintable for ZKVMChipProofInput {} @@ -499,8 +482,6 @@ impl From<(usize, ZKVMChipProof, usize, usize)> for ZKVMChipProofInput { EccQuarkProofInput::dummy() }, num_instances: p.num_instances, - wits_in_evals: p.wits_in_evals, - fixed_in_evals: p.fixed_in_evals, } } } @@ -531,9 +512,6 @@ pub struct ZKVMChipProofInputVariable { pub num_instances: Array>, pub n_inst_0_bit_decomps: Array>, pub n_inst_1_bit_decomps: Array>, - - pub fixed_in_evals: Array>, - pub wits_in_evals: Array>, } impl Hintable for ZKVMChipProofInput { type HintVariable = ZKVMChipProofInputVariable; @@ -571,11 +549,6 @@ impl Hintable for ZKVMChipProofInput { let n_inst_0_bit_decomps = Vec::::read(builder); let n_inst_1_bit_decomps = Vec::::read(builder); - builder.cycle_tracker_start("read wit/fixed evals"); - let fixed_in_evals = Vec::::read(builder); - let wits_in_evals = Vec::::read(builder); - builder.cycle_tracker_end("read wit/fixed evals"); - ZKVMChipProofInputVariable { idx, idx_felt, @@ -597,8 +570,6 @@ impl Hintable for ZKVMChipProofInput { num_instances, n_inst_0_bit_decomps, n_inst_1_bit_decomps, - fixed_in_evals, - wits_in_evals, } } @@ -674,9 +645,6 @@ impl Hintable for ZKVMChipProofInput { stream.extend(n_inst_0_bit_decomps.write()); stream.extend(n_inst_1_bit_decomps.write()); - stream.extend(self.fixed_in_evals.write()); - stream.extend(self.wits_in_evals.write()); - stream } } diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index e0d823814..696b4ca8b 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -26,7 +26,10 @@ use crate::{ SepticExtensionVariable, SepticPointVariable, SumcheckLayerProofVariable, }, }; -use ceno_zkvm::structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}; +use ceno_zkvm::{ + instructions::riscv::constants::{END_CYCLE_IDX, END_PC_IDX, INIT_CYCLE_IDX, INIT_PC_IDX}, + structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}, +}; use ff_ext::BabyBearExt4; use crate::transcript::{challenger_add_forked_index, clone_challenger_state}; @@ -104,23 +107,13 @@ pub fn verify_zkvm_proof>( let prod_w: Ext = builder.constant(C::EF::ONE); let logup_sum: Ext = builder.constant(C::EF::ZERO); - iter_zip!(builder, zkvm_proof_input.raw_pi).for_each(|ptr_vec, builder| { - let v = builder.iter_ptr_get(&zkvm_proof_input.raw_pi, ptr_vec[0]); - challenger_multi_observe(builder, &mut challenger, &v); - }); - - iter_zip!(builder, zkvm_proof_input.raw_pi, zkvm_proof_input.pi_evals).for_each( - |ptr_vec, builder| { - let raw = builder.iter_ptr_get(&zkvm_proof_input.raw_pi, ptr_vec[0]); - let eval = builder.iter_ptr_get(&zkvm_proof_input.pi_evals, ptr_vec[1]); - let raw0 = builder.get(&raw, 0); - - builder.if_eq(raw.len(), Usize::from(1)).then(|builder| { - let raw0_ext = builder.ext_from_base_slice(&[raw0]); - builder.assert_ext_eq(raw0_ext, eval); - }); - }, - ); + for (_, circuit_vk) in vk.circuit_vks.iter() { + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance.iter() { + let raw = builder.get(&zkvm_proof_input.pi, instance_value.0); + // Match native verifier transcript behavior: append base-field PI element directly. + challenger.observe(builder, raw); + } + } builder .if_eq(zkvm_proof_input.shard_id.clone(), Usize::from(0)) @@ -273,14 +266,6 @@ pub fn verify_zkvm_proof>( // fork transcript to support chip concurrently proved let mut chip_challenger = clone_challenger_state(builder, &challenger); challenger_add_forked_index(builder, &mut chip_challenger, &forked_sample_index); - builder.assert_usize_eq( - chip_proof.wits_in_evals.len(), - Usize::from(circuit_vk.get_cs().num_witin()), - ); - builder.assert_usize_eq( - chip_proof.fixed_in_evals.len(), - Usize::from(circuit_vk.get_cs().num_fixed()), - ); builder.assert_usize_eq( chip_proof.rw_out_evals.length.clone(), Usize::from( @@ -293,37 +278,35 @@ pub fn verify_zkvm_proof>( ); chip_challenger.observe(builder, chip_proof.idx_felt); - if !circuit_vk.get_cs().is_with_lk_table() { - // getting the number of dummy padding item that we used in this opcode circuit - let num_lks: Var = - builder.eval(C::N::from_canonical_usize(chip_vk.get_cs().num_lks())); - - // each padding instance contribute to (2^rotation_vars) dummy lookup padding - let next_pow2_instance: Var = - pow_2(builder, chip_proof.log2_num_instances.get_var()); - let num_padded_instance: Var = - builder.eval(next_pow2_instance - chip_proof.sum_num_instances.clone()); - let rotation_var: Var = builder.constant(C::N::from_canonical_usize( - 1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0), + // getting the number of dummy padding item that we used in this opcode circuit + let num_lks: Var = + builder.eval(C::N::from_canonical_usize(chip_vk.get_cs().num_lks())); + + // each padding instance contribute to (2^rotation_vars) dummy lookup padding + let next_pow2_instance: Var = + pow_2(builder, chip_proof.log2_num_instances.get_var()); + let num_padded_instance: Var = + builder.eval(next_pow2_instance - chip_proof.sum_num_instances.clone()); + let rotation_var: Var = builder.constant(C::N::from_canonical_usize( + 1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0), + )); + let rotation_subgroup_size: Var = + builder.constant(C::N::from_canonical_usize( + circuit_vk.get_cs().rotation_subgroup_size().unwrap_or(0), )); - let rotation_subgroup_size: Var = - builder.constant(C::N::from_canonical_usize( - circuit_vk.get_cs().rotation_subgroup_size().unwrap_or(0), - )); - builder.assign(&num_padded_instance, num_padded_instance * rotation_var); - - // each instance contribute to (2^rotation_vars - rotated) dummy lookup padding - let num_instance_non_selected: Var = builder.eval( - chip_proof.sum_num_instances.clone() - * (rotation_var - rotation_subgroup_size - C::N::ONE), - ); - let new_multiplicity: Var = - builder.eval(num_lks * (num_padded_instance + num_instance_non_selected)); - builder.assign( - &dummy_table_item_multiplicity, - dummy_table_item_multiplicity + new_multiplicity, - ); - } + builder.assign(&num_padded_instance, num_padded_instance * rotation_var); + + // each instance contribute to (2^rotation_vars - rotated) dummy lookup padding + let num_instance_non_selected: Var = builder.eval( + chip_proof.sum_num_instances.clone() + * (rotation_var - rotation_subgroup_size - C::N::ONE), + ); + let new_multiplicity: Var = + builder.eval(num_lks * (num_padded_instance + num_instance_non_selected)); + builder.assign( + &dummy_table_item_multiplicity, + dummy_table_item_multiplicity + new_multiplicity, + ); builder.cycle_tracker_start("Verify chip proof"); let ( @@ -336,9 +319,7 @@ pub fn verify_zkvm_proof>( builder, &mut chip_challenger, &chip_proof, - &zkvm_proof_input.pi_evals, - &zkvm_proof_input.raw_pi, - &zkvm_proof_input.raw_pi_num_variables, + &zkvm_proof_input.pi, &challenges, chip_vk, &unipoly_extrapolator, @@ -365,21 +346,23 @@ pub fn verify_zkvm_proof>( builder.assign(&chip_logup_sum, chip_logup_sum + p2 * q2.inverse()); }); - if circuit_vk.get_cs().is_with_lk_table() { - builder.assign(&logup_sum, logup_sum - chip_logup_sum); - } else { - builder.assign(&logup_sum, logup_sum + chip_logup_sum); - } + builder.assign(&logup_sum, logup_sum + chip_logup_sum); let point_clone: Array> = builder.eval(input_opening_point.clone()); + let (wits_in_evals, fixed_in_evals) = split_input_opening_evals( + builder, + &chip_proof, + circuit_vk.get_cs().num_witin(), + circuit_vk.get_cs().num_fixed(), + ); if circuit_vk.get_cs().num_witin() > 0 { let witin_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { num_var: input_opening_point.len().get_var(), point_and_evals: PointAndEvalsVariable { point: PointVariable { fs: point_clone }, - evals: chip_proof.wits_in_evals, + evals: wits_in_evals, }, }); builder.set_value(&witin_openings, num_witin_openings.get_var(), witin_round); @@ -392,7 +375,7 @@ pub fn verify_zkvm_proof>( point: PointVariable { fs: input_opening_point, }, - evals: chip_proof.fixed_in_evals, + evals: fixed_in_evals, }, }); @@ -522,6 +505,17 @@ pub fn verify_zkvm_proof>( &unipoly_extrapolator, &mut challenger, ); + // Global-state expressions are defined over compact/query-order PI slots. + // Keep this aligned with ceno_zkvm verifier: [init_pc, init_cycle, end_pc, end_cycle]. + let global_state_pi_evals: Array> = builder.dyn_array(4); + [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] + .into_iter() + .enumerate() + .for_each(|(i, idx)| { + let raw = builder.get(&zkvm_proof_input.pi, idx); + let eval = builder.ext_from_base_slice(&[raw]); + builder.set(&global_state_pi_evals, i, eval); + }); let empty_arr: Array> = builder.dyn_array(0); let initial_global_state = eval_ceno_expr_with_instance( @@ -529,7 +523,7 @@ pub fn verify_zkvm_proof>( &empty_arr, &empty_arr, &empty_arr, - &zkvm_proof_input.pi_evals, + &global_state_pi_evals, &challenges, &vk.initial_global_state_expr, ); @@ -540,7 +534,7 @@ pub fn verify_zkvm_proof>( &empty_arr, &empty_arr, &empty_arr, - &zkvm_proof_input.pi_evals, + &global_state_pi_evals, &challenges, &vk.finalize_global_state_expr, ); @@ -556,14 +550,35 @@ pub fn verify_zkvm_proof>( shard_ec_sum } +fn split_input_opening_evals( + builder: &mut Builder, + chip_proof: &ZKVMChipProofInputVariable, + num_witin: usize, + num_fixed: usize, +) -> (Array>, Array>) { + let last_layer_idx: Usize = + builder.eval(chip_proof.gkr_iop_proof.layer_proofs.len() - Usize::from(1)); + let last_layer = builder.get(&chip_proof.gkr_iop_proof.layer_proofs, last_layer_idx); + let main_evals = last_layer.main.evals; + + let wit_end = Usize::from(num_witin); + let fixed_end: Usize = builder.eval(wit_end.clone() + Usize::from(num_fixed)); + // Native verifier accepts extra trailing evals; only the prefix is consumed here. + // Keep recursion semantics aligned by slicing the required prefix. + let eval_prefix = main_evals.slice(builder, Usize::from(0), fixed_end.clone()); + + ( + eval_prefix.slice(builder, Usize::from(0), wit_end), + eval_prefix.slice(builder, Usize::from(num_witin), fixed_end), + ) +} + pub fn verify_chip_proof( circuit_name: &str, builder: &mut Builder, challenger: &mut DuplexChallengerVariable, chip_proof: &ZKVMChipProofInputVariable, - pi_evals: &Array>, - raw_pi: &Array>>, - raw_pi_num_variables: &Array>, + pi: &Array>, challenges: &Array>, vk: &VerifyingKey, unipoly_extrapolator: &UniPolyExtrapolator, @@ -709,6 +724,13 @@ pub fn verify_chip_proof( builder.set(&q_slice, idx_vec[0], cpt); }); let gkr_circuit = gkr_circuit.clone().unwrap(); + let circuit_pi_evals: Array> = + builder.dyn_array(Usize::from(cs.instance.len())); + for (i, instance) in cs.instance.iter().enumerate() { + let raw = builder.get(pi, instance.0); + let eval = builder.ext_from_base_slice(&[raw]); + builder.set(&circuit_pi_evals, i, eval); + } let zero_bit_decomps: Array> = builder.dyn_array(32); let selector_ctxs: Vec> = if cs.ec_final_sum.is_empty() { @@ -807,11 +829,8 @@ pub fn verify_chip_proof( gkr_circuit, &chip_proof.gkr_iop_proof, challenges, - pi_evals, - raw_pi, - raw_pi_num_variables, + &circuit_pi_evals, &out_evals, - chip_proof, selector_ctxs, unipoly_extrapolator, poly_evaluator, @@ -829,13 +848,10 @@ pub fn verify_gkr_circuit( gkr_proof: &GKRProofVariable, challenges: &Array>, pub_io_evals: &Array>, - raw_pi: &Array>>, - raw_pi_num_variables: &Array>, claims: &Array>, - _chip_proof: &ZKVMChipProofInputVariable, selector_ctxs: Vec>, unipoly_extrapolator: &UniPolyExtrapolator, - poly_evaluator: &mut PolyEvaluator, + _poly_evaluator: &mut PolyEvaluator, ) -> PointVariable { let rt = PointVariable { fs: builder.dyn_array(0), @@ -1124,18 +1140,6 @@ pub fn verify_gkr_circuit( builder.assert_ext_eq(expected_eval, main_wit_eval); } - let pubio_offset = layer.n_witin + layer.n_fixed; - for (index, instance) in layer.instance_openings.iter().enumerate() { - let index: usize = pubio_offset + index; - let poly = builder.get(raw_pi, instance.0); - let num_variable = builder.get(raw_pi_num_variables, instance.0); - let in_point_slice = in_point.slice(builder, 0, num_variable); - let expected_eval = - poly_evaluator.evaluate_base_poly_at_point(builder, &poly, &in_point_slice); - let main_eval = builder.get(&main_evals, index); - builder.assert_ext_eq(expected_eval, main_eval); - } - // TODO: we should store alpha_pows in a bigger array to avoid concatenating them let main_sumcheck_challenges_len: Usize = builder.eval(alpha_pows.len() + Usize::from(2)); diff --git a/ceno_rt/Cargo.toml b/ceno_rt/Cargo.toml index f211c53ce..83b25cc3e 100644 --- a/ceno_rt/Cargo.toml +++ b/ceno_rt/Cargo.toml @@ -11,6 +11,8 @@ version = "0.1.0" [dependencies] ceno_serde = { path = "../ceno_serde" } +ceno_syscall.workspace = true getrandom = { version = "0.2.15", features = ["custom"], default-features = false } getrandom_v3 = { package = "getrandom", version = "0.3", default-features = false } serde.workspace = true +tiny-keccak.workspace = true diff --git a/ceno_rt/ceno_link.x b/ceno_rt/ceno_link.x index f4c633ba0..0aa7dd928 100644 --- a/ceno_rt/ceno_link.x +++ b/ceno_rt/ceno_link.x @@ -4,11 +4,7 @@ _hints_start = ORIGIN(REGION_HINTS) + 128M; _hints_length = 128M; _lengths_of_hints_start = ORIGIN(REGION_HINTS) + 128M; -_lengths_of_pubio_start = ORIGIN(REGION_PUBIO); -_pubio_start = ORIGIN(REGION_PUBIO); /* 0x20000000 */ -_pubio_end = ORIGIN(REGION_PUBIO) + 128M; /* PUBIO grows upward */ -_pubio_length = 128M; -_stack_start = ORIGIN(REGION_PUBIO) + 256M; /* stack grows downward */ +_stack_start = ORIGIN(REGION_STACK) + LENGTH(REGION_STACK); /* stack grows downward */ SECTIONS { @@ -25,10 +21,6 @@ SECTIONS *(.rodata .rodata.*); } > ROM - .pubio (NOLOAD): ALIGN(4) - { - *(.pubio .pubio.*); - } > STACK_PUBIO .stack (NOLOAD) : ALIGN(4) { diff --git a/ceno_rt/memory.x b/ceno_rt/memory.x index 2f95e9ae4..d42358be7 100644 --- a/ceno_rt/memory.x +++ b/ceno_rt/memory.x @@ -1,7 +1,7 @@ MEMORY { ROM (rx) : ORIGIN = 0x08000000, LENGTH = 128M - STACK_PUBIO (rw) : ORIGIN = 0x10000000, LENGTH = 256M /* PUBIO first 128M, Stack second 128M */ + STACK_PUBIO (rw) : ORIGIN = 0x10000000, LENGTH = 256M /* Stack region */ HINTS (r) : ORIGIN = 0x20000000, LENGTH = 256M /* will shift hint to 0x28000000 with 128M to reserve gap*/ RAM (rw) : ORIGIN = 0x30000000, LENGTH = 256M /* heap/data/bss */ } @@ -9,7 +9,6 @@ MEMORY REGION_ALIAS("REGION_TEXT", ROM); REGION_ALIAS("REGION_RODATA", ROM); -REGION_ALIAS("REGION_PUBIO", STACK_PUBIO); REGION_ALIAS("REGION_STACK", STACK_PUBIO); REGION_ALIAS("REGION_HINTS", HINTS); diff --git a/ceno_rt/src/lib.rs b/ceno_rt/src/lib.rs index d3c08b298..dd0d4871f 100644 --- a/ceno_rt/src/lib.rs +++ b/ceno_rt/src/lib.rs @@ -13,7 +13,7 @@ use std::{ mod allocator; mod mmio; -pub use mmio::{commit, read, read_owned, read_slice}; +pub use mmio::{CommitCtx, commit, commit_digest, read, read_owned, read_slice}; mod io; #[cfg(debug_assertions)] diff --git a/ceno_rt/src/mmio.rs b/ceno_rt/src/mmio.rs index ead07a7ab..e3dd90801 100644 --- a/ceno_rt/src/mmio.rs +++ b/ceno_rt/src/mmio.rs @@ -1,8 +1,11 @@ //! Memory-mapped I/O (MMIO) functions. use ceno_serde::from_slice; +use ceno_syscall::syscall_pub_io_commit; use core::{cell::UnsafeCell, ptr, slice::from_raw_parts}; use serde::de::DeserializeOwned; +use std::vec::Vec; +use tiny_keccak::{Hasher, Keccak}; struct RegionState { next_len_at: *const usize, @@ -58,8 +61,6 @@ impl RegionState { unsafe extern "C" { static _hints_start: u8; static _lengths_of_hints_start: usize; - static _pubio_start: u8; - static _lengths_of_pubio_start: usize; } struct RegionStateCell(UnsafeCell); @@ -77,7 +78,6 @@ impl RegionStateCell { unsafe impl Sync for RegionStateCell {} static HINT_STATE: RegionStateCell = RegionStateCell::new(); -static PUBIO_STATE: RegionStateCell = RegionStateCell::new(); pub fn read_slice<'a>() -> &'a [u8] { unsafe { @@ -100,18 +100,110 @@ where read_owned() } -pub fn pubio_read_slice<'a>() -> &'a [u8] { - unsafe { - PUBIO_STATE - .with_mut(|state| state.take_slice(&raw const _lengths_of_pubio_start, &_pubio_start)) +fn digest_to_words(digest: [u8; 32]) -> [u32; 8] { + core::array::from_fn(|i| { + u32::from_le_bytes([ + digest[i * 4], + digest[i * 4 + 1], + digest[i * 4 + 2], + digest[i * 4 + 3], + ]) + }) +} + +fn keccak_words(bytes: &[u8]) -> [u32; 8] { + let mut keccak = Keccak::v256(); + keccak.update(bytes); + let mut digest = [0u8; 32]; + keccak.finalize(&mut digest); + digest_to_words(digest) +} + +/// Commit a precomputed public-io digest. +/// +/// The input must already be an 8-word Keccak-256 digest encoded in little-endian words. +pub fn commit_digest(digest_words: [u32; 8]) { + syscall_pub_io_commit(&digest_words); +} + +/// Accumulates committed bytes and emits one digest when finalized. +#[derive(Clone, Debug, Default)] +pub struct CommitCtx { + bytes: Vec, +} + +impl CommitCtx { + pub fn new() -> Self { + Self::default() + } + + fn digest_words(&self) -> [u32; 8] { + keccak_words(&self.bytes) + } + + /// Append arbitrary bytes to this context. + pub fn commit(&mut self, data: &[u8]) { + self.bytes.extend_from_slice(data); + } + + /// Compute a final digest by hashing the accumulated bytes once. + pub fn finalized(self) { + commit_digest(self.digest_words()) } } -/// Read a value from public io, deserialize it, and assert that it matches the given value. -pub fn commit(v: &T) -where - T: DeserializeOwned + core::fmt::Debug + PartialEq, -{ - let expected: T = from_slice(pubio_read_slice()).expect("Deserialised value failed."); - assert_eq!(*v, expected); +/// Commit arbitrary public bytes by hashing with Keccak-256 and emitting digest limbs. +pub fn commit(data: &[u8]) { + commit_digest(keccak_words(data)); +} + +#[cfg(test)] +mod tests { + use super::{CommitCtx, digest_to_words, keccak_words}; + use tiny_keccak::{Hasher, Keccak}; + + #[test] + fn keccak_words_matches_manual_conversion() { + let words = keccak_words(b"hello world"); + + let mut manual = Keccak::v256(); + manual.update(b"hello world"); + let mut digest = [0u8; 32]; + manual.finalize(&mut digest); + + assert_eq!(words, digest_to_words(digest)); + } + + #[test] + fn commit_ctx_digest_words_hashes_all_appended_bytes() { + let mut ctx = CommitCtx::new(); + ctx.commit(b"abc"); + ctx.commit(b"123"); + + let got = ctx.digest_words(); + let expected = keccak_words(b"abc123"); + + assert_eq!(got, expected); + } + + #[test] + fn commit_ctx_commit_appends_raw_bytes_before_finalize() { + let mut ctx = CommitCtx::new(); + ctx.commit(b"hello"); + ctx.commit(b" "); + ctx.commit(b"world"); + + let got = ctx.digest_words(); + let expected = keccak_words(b"hello world"); + + assert_eq!(got, expected); + } + + #[test] + #[should_panic(expected = "syscall_pub_io_commit should only run inside zkvm")] + fn commit_ctx_finalized_is_callable() { + let mut ctx = CommitCtx::new(); + ctx.commit(b"payload"); + ctx.finalized(); + } } diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 3823219d7..ab7e254d3 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -112,8 +112,7 @@ fn bench_add(c: &mut Criterion) { fixed: vec![], witness: polys, structural_witness: vec![], - public_input: vec![], - pub_io_evals: vec![], + pi: vec![], num_instances: vec![num_instances], has_ecc_ops: false, }; diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 09e5a1131..67af854e6 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -19,7 +19,6 @@ use gkr_iop::hal::ProverBackend; use mpcs::{ Basefold, BasefoldRSParams, PolynomialCommitmentScheme, SecurityLevel, Whir, WhirDefaultSpec, }; -use p3::field::FieldAlgebra; use serde::{Serialize, de::DeserializeOwned}; use std::{fs, panic, panic::AssertUnwindSafe, path::PathBuf}; use tracing::{error, level_filters::LevelFilter}; @@ -175,26 +174,8 @@ fn main() { .with(args.profiling.is_none().then_some(default_filter)) .init(); - // process public input first - let public_io = args - .public_io - .and_then(|public_io| { - // if the vector contains only one element, write it as a raw `u32` - // otherwise, write the entire vector - // in both cases, convert the resulting `CenoStdin` into a `Vec` - if public_io.len() == 1 { - CenoStdin::default() - .write(&public_io[0]) - .ok() - .map(|stdin| Into::>::into(&*stdin)) - } else { - CenoStdin::default() - .write(&public_io) - .ok() - .map(|stdin| Into::>::into(&*stdin)) - } - }) - .unwrap_or_default(); + // process public input first; this is raw u32 public input, not pre-digested words. + let public_io = args.public_io.unwrap_or_default(); assert!( public_io.len() <= args.public_io_size as usize / WORD_SIZE, "require pub io length {} < max public_io_size {}", @@ -404,8 +385,7 @@ fn soundness_test>( // do sanity check let transcript = Transcript::new(b"riscv"); // change public input maliciously should cause verifier to reject proof - zkvm_proof.raw_pi[0] = vec![E::BaseField::ONE]; - zkvm_proof.raw_pi[1] = vec![E::BaseField::ONE]; + zkvm_proof.public_values.exit_code = 1; // capture panic message, if have let result = with_panic_hook(Box::new(|_info| ()), || { @@ -428,7 +408,7 @@ fn soundness_test>( unreachable!() }; - if !msg.starts_with("0th round's prover message is not consistent with the claim") { + if !msg.starts_with("assertion `left == right` failed") { error!("unknown panic {msg:?}"); panic::resume_unwind(err); }; diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index bb63ce504..ef2a5f9bf 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -5,8 +5,8 @@ use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, - HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, - SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, + HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBIO_DIGEST_IDX, + PUBIO_DIGEST_U16_LIMBS, PUBLIC_IO_IDX, SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, }, scheme::constants::SEPTIC_EXTENSION_DEGREE, tables::InsnRecord, @@ -24,15 +24,19 @@ pub trait PublicValuesQuery { fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError>; + #[allow(dead_code)] fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; + fn query_public_io_digest( + &mut self, + ) -> Result<[Instance; PUBIO_DIGEST_U16_LIMBS], CircuitBuilderError>; #[allow(dead_code)] fn query_shard_id(&mut self) -> Result; - fn query_heap_start_addr(&self) -> Result; + fn query_heap_start_addr(&mut self) -> Result; #[allow(dead_code)] - fn query_heap_shard_len(&self) -> Result; - fn query_hint_start_addr(&self) -> Result; + fn query_heap_shard_len(&mut self) -> Result; + fn query_hint_start_addr(&mut self) -> Result; #[allow(dead_code)] - fn query_hint_shard_len(&self) -> Result; + fn query_hint_shard_len(&mut self) -> Result; } impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { @@ -86,27 +90,38 @@ impl<'a, E: ExtensionField> PublicValuesQuery for CircuitBuilder<'a, E> { fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> { Ok([ - self.cs.query_instance_for_openings(PUBLIC_IO_IDX)?, - self.cs.query_instance_for_openings(PUBLIC_IO_IDX + 1)?, + self.cs.query_instance(PUBLIC_IO_IDX)?, + self.cs.query_instance(PUBLIC_IO_IDX + 1)?, ]) } + fn query_public_io_digest( + &mut self, + ) -> Result<[Instance; PUBIO_DIGEST_U16_LIMBS], CircuitBuilderError> { + let limbs = (0..PUBIO_DIGEST_U16_LIMBS) + .map(|i| self.cs.query_instance(PUBIO_DIGEST_IDX + i)) + .collect::, _>>()?; + Ok(limbs + .try_into() + .expect("pubio digest instance limb count must be fixed")) + } + fn query_shard_id(&mut self) -> Result { self.cs.query_instance(SHARD_ID_IDX) } - fn query_heap_start_addr(&self) -> Result { + fn query_heap_start_addr(&mut self) -> Result { self.cs.query_instance(HEAP_START_ADDR_IDX) } - fn query_heap_shard_len(&self) -> Result { + fn query_heap_shard_len(&mut self) -> Result { self.cs.query_instance(HEAP_LENGTH_IDX) } - fn query_hint_start_addr(&self) -> Result { + fn query_hint_start_addr(&mut self) -> Result { self.cs.query_instance(HINT_START_ADDR_IDX) } - fn query_hint_shard_len(&self) -> Result { + fn query_hint_shard_len(&mut self) -> Result { self.cs.query_instance(HINT_LENGTH_IDX) } } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 215cbf7b6..1fbce7251 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -34,9 +34,9 @@ use ff_ext::{ExtensionField, SmallField}; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; use gkr_iop::{RAMType, hal::ProverBackend}; +use itertools::Itertools; #[cfg(debug_assertions)] -use itertools::MinMaxResult; -use itertools::{Itertools, chain}; +use itertools::{MinMaxResult, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; use rustc_hash::FxHashSet; @@ -49,6 +49,7 @@ use std::{ ops::Range, sync::Arc, }; +use tiny_keccak::{Hasher, Keccak}; use tracing::info_span; use transcript::BasicTranscript as Transcript; use witness::next_pow2_instance_padding; @@ -57,6 +58,19 @@ use witness::next_pow2_instance_padding; pub const DEFAULT_MAX_CELLS_PER_SHARDS: u64 = (1 << 30) * 16 / 4 / 2; pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 29; pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT: usize = 1 << 20; + +pub fn public_io_words_to_digest_words(words: &[u32]) -> [u32; 8] { + let mut keccak = Keccak::v256(); + for word in words { + keccak.update(&word.to_le_bytes()); + } + let mut digest = [0u8; 32]; + keccak.finalize(&mut digest); + + // Reinterpret Keccak digest bytes as 8 little-endian u32 words. + unsafe { core::mem::transmute::<[u8; 32], [u32; 8]>(digest) } +} + // define a relative small number to make first shard handle much less instruction /// The polynomial commitment scheme kind #[derive( @@ -770,7 +784,7 @@ impl StepReplay { ) -> Self { let mut vm = VMState::new_with_tracer_config(platform, program, FullTracerConfig { max_step_shard }); - for record in chain!(init_mem_state.hints.iter(), init_mem_state.io.iter()) { + for record in init_mem_state.hints.iter() { vm.init_memory(record.addr.into(), record.value); } StepReplay { @@ -828,13 +842,14 @@ pub fn emulate_program<'a>( program: Arc, max_steps: usize, init_mem_state: &InitMemState, + public_io_digest_input: &[u32], platform: &Platform, multi_prover: &MultiProver, step_cell_extractor: Arc, ) -> EmulationResult<'a> { let InitMemState { mem: mem_init, - io: io_init, + io: _, reg: reg_init, hints: hints_init, stack: _, @@ -853,7 +868,7 @@ pub fn emulate_program<'a>( }); info_span!("[ceno] emulator.init_mem").in_scope(|| { - for record in chain!(hints_init, io_init) { + for record in hints_init { vm.init_memory(record.addr.into(), record.value); } }); @@ -937,17 +952,8 @@ pub fn emulate_program<'a>( }) .collect_vec(); - // Find the final public IO cycles. - let io_final = io_init - .iter() - .map(|rec| MemFinalRecord { - ram_type: RAMType::Memory, - addr: rec.addr, - value: rec.value, - init_value: rec.value, - cycle: final_access.cycle(rec.addr.into()), - }) - .collect_vec(); + // Legacy public-io memory init is removed. + let io_final: Vec = vec![]; // Find the final hints IO cycles. let hints_final = hints_init @@ -1020,8 +1026,8 @@ pub fn emulate_program<'a>( heap_final.len() as u32, platform.hints.start, hints_final.len() as u32, - io_init.iter().map(|rec| rec.value).collect_vec(), - vec![0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity + public_io_words_to_digest_words(public_io_digest_input), + [0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); #[cfg(debug_assertions)] @@ -1139,17 +1145,13 @@ fn setup_platform_inner( heap.start..heap_end as u32 }; - assert!( - pub_io_size.is_power_of_two(), - "pub io size {pub_io_size} must be a power of two" - ); + let _ = pub_io_size; let platform = Platform { rom: program.base_address ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, prog_data, stack, heap, - public_io: preset.public_io.start..preset.public_io.start + pub_io_size, ..preset }; assert!( @@ -1212,7 +1214,6 @@ pub fn generate_fixed_traces( system_config: &ConstraintSystemConfig, reg_init: &[MemInitRecord], static_mem_init: &[MemInitRecord], - io_addrs: &[Addr], program: &Program, ) -> ZKVMFixedTraces { let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -1231,7 +1232,6 @@ pub fn generate_fixed_traces( &mut zkvm_fixed_traces, reg_init, static_mem_init, - io_addrs, ); system_config .dummy_config @@ -1406,7 +1406,6 @@ pub fn generate_witness<'a, E: ExtensionField>( &pi, &emul_result.final_mem_state.reg, &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, &emul_result.final_mem_state.stack, ) .unwrap(); @@ -1421,7 +1420,6 @@ pub fn generate_witness<'a, E: ExtensionField>( &[], &[], &[], - &[], ) .unwrap(); } @@ -1451,7 +1449,6 @@ pub fn generate_witness<'a, E: ExtensionField>( &pi, &emul_result.final_mem_state.reg, &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, &emul_result.final_mem_state.hints, &emul_result.final_mem_state.stack, &emul_result.final_mem_state.heap, @@ -1523,7 +1520,6 @@ pub struct E2EProgramCtx { pub pubio_len: usize, pub system_config: ConstraintSystemConfig, pub reg_init: Vec, - pub io_init: Vec, pub zkvm_fixed_traces: ZKVMFixedTraces, } @@ -1552,7 +1548,7 @@ pub fn setup_program( multi_prover: MultiProver, ) -> E2EProgramCtx { let static_addrs = init_static_addrs(&program); - let pubio_len = platform.public_io.iter_addresses().len(); + let pubio_len = 0; let program_params = ProgramParams { platform: platform.clone(), program_size: next_pow2_instance_padding(program.instructions.len()), @@ -1561,16 +1557,9 @@ pub fn setup_program( }; let system_config = construct_configs::(program_params); let reg_init = system_config.mmu_config.initial_registers(); - let io_init = MemPadder::new_mem_records_uninit(platform.public_io.clone(), pubio_len); - // Generate fixed traces - let zkvm_fixed_traces = generate_fixed_traces( - &system_config, - ®_init, - &static_addrs, - &io_init.iter().map(|rec| rec.addr).collect_vec(), - &program, - ); + let zkvm_fixed_traces = + generate_fixed_traces(&system_config, ®_init, &static_addrs, &program); E2EProgramCtx { program: Arc::new(program), @@ -1580,7 +1569,6 @@ pub fn setup_program( pubio_len, system_config, reg_init, - io_init, zkvm_fixed_traces, } } @@ -1634,9 +1622,8 @@ impl E2EProgramCtx { } /// Setup init mem state - pub fn setup_init_mem(&self, hints: &[u32], public_io: &[u32]) -> InitMemState { - let mut io_init = self.io_init.clone(); - MemPadder::init_mem_records(&mut io_init, public_io); + pub fn setup_init_mem(&self, hints: &[u32], public_io_digest_input: &[u32]) -> InitMemState { + let _ = public_io_digest_input; let hint_init = MemPadder::new_mem_records( self.platform.hints.clone(), hints.len().next_power_of_two(), @@ -1646,7 +1633,7 @@ impl E2EProgramCtx { InitMemState { mem: self.static_addrs.clone(), reg: self.reg_init.clone(), - io: io_init, + io: vec![], hints: hint_init, // stack/heap both init value 0 and range is dynamic stack: vec![], @@ -1678,7 +1665,7 @@ pub fn run_e2e_with_checkpoint< platform: Platform, multi_prover: MultiProver, hints: &[u32], - public_io: &[u32], + public_io_digest_input: &[u32], max_steps: usize, checkpoint: Checkpoint, // for debug purpose @@ -1697,12 +1684,13 @@ pub fn run_e2e_with_checkpoint< let prover = ZKVMProver::new(pk.into(), device); let start = std::time::Instant::now(); - let init_full_mem = prover.setup_init_mem(hints, public_io); + let init_full_mem = prover.setup_init_mem(hints, public_io_digest_input); tracing::debug!("setup_init_mem done in {:?}", start.elapsed()); // Generate witness let is_mock_proving = std::env::var("MOCK_PROVING").is_ok(); if let Checkpoint::PrepE2EProving = checkpoint { + let public_io_digest_input_owned = public_io_digest_input.to_vec(); return E2ECheckpointResult { proofs: None, vk: Some(vk), @@ -1710,6 +1698,7 @@ pub fn run_e2e_with_checkpoint< _ = run_e2e_proof::( &prover, &init_full_mem, + &public_io_digest_input_owned, max_steps, is_mock_proving, target_shard_id, @@ -1727,6 +1716,7 @@ pub fn run_e2e_with_checkpoint< prover.pk.program_ctx.as_ref().unwrap().program.clone(), max_steps, &init_full_mem, + public_io_digest_input, &prover.pk.program_ctx.as_ref().unwrap().platform, &prover.pk.program_ctx.as_ref().unwrap().multi_prover, step_cell_extractor, @@ -1807,6 +1797,7 @@ pub fn run_e2e_proof< >( prover: &ZKVMProver, init_full_mem: &InitMemState, + public_io_digest_input: &[u32], max_steps: usize, is_mock_proving: bool, // for debug purpose @@ -1820,6 +1811,7 @@ pub fn run_e2e_proof< ctx.program.clone(), max_steps, init_full_mem, + public_io_digest_input, &ctx.platform, &ctx.multi_prover, step_cell_extractor, @@ -2125,6 +2117,7 @@ mod tests { use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord}; use itertools::Itertools; use std::sync::Arc; + use tiny_keccak::{Hasher, Keccak}; #[test] fn test_single_prover_shard_ctx() { @@ -2254,4 +2247,27 @@ mod tests { } } } + + #[test] + fn public_io_digest_matches_guest_commit_hashing() { + let words = vec![4191u32]; + let digest_words = super::public_io_words_to_digest_words(&words); + + let mut keccak = Keccak::v256(); + for word in &words { + keccak.update(&word.to_le_bytes()); + } + let mut digest = [0u8; 32]; + keccak.finalize(&mut digest); + let expected = core::array::from_fn(|i| { + u32::from_le_bytes([ + digest[i * 4], + digest[i * 4 + 1], + digest[i * 4 + 2], + digest[i * 4 + 3], + ]) + }); + + assert_eq!(digest_words, expected); + } } diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 61e673246..09feb6280 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -1,8 +1,10 @@ -use crate::uint::UIntLimbs; +use crate::{scheme::constants::SEPTIC_EXTENSION_DEGREE, uint::UIntLimbs}; pub use ceno_emul::PC_STEP_SIZE; pub const ECALL_HALT_OPCODE: [usize; 2] = [0x00_00, 0x00_00]; pub const EXIT_PC: usize = 0; + +/// scalar-based public value, id start from 0 pub const EXIT_CODE_IDX: usize = 0; // exit code u32 occupied 2 limb, each with 16 pub const INIT_PC_IDX: usize = EXIT_CODE_IDX + 2; @@ -14,8 +16,13 @@ pub const HEAP_START_ADDR_IDX: usize = SHARD_ID_IDX + 1; pub const HEAP_LENGTH_IDX: usize = HEAP_START_ADDR_IDX + 1; pub const HINT_START_ADDR_IDX: usize = HEAP_LENGTH_IDX + 1; pub const HINT_LENGTH_IDX: usize = HINT_START_ADDR_IDX + 1; -pub const PUBLIC_IO_IDX: usize = HINT_LENGTH_IDX + 1; -pub const SHARD_RW_SUM_IDX: usize = PUBLIC_IO_IDX + 2; + +pub const SHARD_RW_SUM_IDX: usize = HINT_LENGTH_IDX + 1; +pub const PUBIO_DIGEST_IDX: usize = SHARD_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE * 2; +pub const PUBIO_DIGEST_U16_LIMBS: usize = 8 * UINT_LIMBS; + +/// vector-based public value, id start from 0 +pub const PUBLIC_IO_IDX: usize = 0; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index 84a836bf3..f286a8173 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -3,6 +3,7 @@ mod fptower_fp2_add; mod fptower_fp2_mul; mod halt; mod keccak; +mod pubio_commit; mod sha_extend; mod uint256; mod weierstrass_add; @@ -13,6 +14,7 @@ pub use fptower_fp::{FpAddInstruction, FpMulInstruction}; pub use fptower_fp2_add::Fp2AddInstruction; pub use fptower_fp2_mul::Fp2MulInstruction; pub use keccak::KeccakInstruction; +pub use pubio_commit::PubIoCommitInstruction; pub use sha_extend::ShaExtendInstruction; pub use uint256::{Secp256k1InvInstruction, Secp256r1InvInstruction, Uint256MulInstruction}; pub use weierstrass_add::WeierstrassAddAssignInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs b/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs new file mode 100644 index 000000000..8ca50708b --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs @@ -0,0 +1,161 @@ +use std::marker::PhantomData; + +use ceno_emul::{ + Change, InsnKind, Platform, PubIoCommitSpec, StepRecord, SyscallSpec, WORD_SIZE, WriteOp, +}; +use ff_ext::ExtensionField; +use multilinear_extensions::ToExpr; +use p3::field::FieldAlgebra; + +use crate::{ + chip_handler::general::InstFetch, + circuit_builder::CircuitBuilder, + e2e::ShardContext, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + constants::{LIMB_BITS, LIMB_MASK, MEM_BITS, UInt}, + ecall_base::OpFixedRS, + insn_base::{MemAddr, StateInOut, WriteMEM}, + }, + }, + precompiles::{PUBIO_COMMIT_WORDS, PubioCommitLayout}, + structs::ProgramParams, + tables::InsnRecord, + witness::LkMultiplicity, +}; + +#[derive(Debug)] +pub struct EcallPubioCommitConfig { + vm_state: StateInOut, + ecall_id: OpFixedRS, + digest_ptr: (OpFixedRS, MemAddr), + mem_rw: [WriteMEM; PUBIO_COMMIT_WORDS], +} + +pub struct PubIoCommitInstruction(PhantomData); + +impl Instruction for PubIoCommitInstruction { + type InstructionConfig = EcallPubioCommitConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } + + fn name() -> String { + "Ecall_PubioCommit".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _param: &ProgramParams, + ) -> Result { + let vm_state = StateInOut::construct_circuit(cb, false)?; + let syscall_code = PubIoCommitSpec::CODE; + + let ecall_id = OpFixedRS::<_, { Platform::reg_ecall() }, false>::construct_circuit( + cb, + UInt::from_const_unchecked(vec![ + syscall_code & LIMB_MASK, + (syscall_code >> LIMB_BITS) & LIMB_MASK, + ]) + .register_expr(), + vm_state.ts, + )?; + + let digest_ptr_value = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; + let digest_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + cb, + digest_ptr_value.uint_unaligned().register_expr(), + vm_state.ts, + )?; + + cb.lk_fetch(&InsnRecord::new( + vm_state.pc.expr(), + InsnKind::ECALL.into(), + None, + 0.into(), + 0.into(), + 0.into(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), + ))?; + + let layout = PubioCommitLayout::construct_circuit(cb)?; + let mem_rw: [WriteMEM; PUBIO_COMMIT_WORDS] = (0..PUBIO_COMMIT_WORDS) + .map(|i| { + WriteMEM::construct_circuit( + cb, + digest_ptr.prev_value.as_ref().unwrap().value() + + E::BaseField::from_canonical_u32((i * WORD_SIZE) as u32).expr(), + layout.digest_words[i].clone(), + layout.digest_words[i].clone(), + vm_state.ts, + ) + }) + .collect::, _>>()? + .try_into() + .expect("pubio read width is fixed"); + + Ok(EcallPubioCommitConfig { + vm_state, + ecall_id, + digest_ptr: (digest_ptr, digest_ptr_value), + mem_rw, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let syscall_code = PubIoCommitSpec::CODE; + let ops = step.syscall().expect("syscall step"); + assert_eq!(ops.reg_ops.len(), 1, "PUB_IO_COMMIT expects 1 reg op"); + assert_eq!( + ops.mem_ops.len(), + PUBIO_COMMIT_WORDS, + "PUB_IO_COMMIT expects {} mem ops", + PUBIO_COMMIT_WORDS + ); + + config.vm_state.assign_instance(instance, shard_ctx, step)?; + + config.ecall_id.assign_op( + instance, + shard_ctx, + lk_multiplicity, + step.cycle(), + &WriteOp::new_register_op( + Platform::reg_ecall(), + Change::new(syscall_code, syscall_code), + step.rs1().unwrap().previous_cycle, + ), + )?; + + config.digest_ptr.1.assign_instance( + instance, + lk_multiplicity, + ops.reg_ops[0].value.after, + )?; + config.digest_ptr.0.assign_op( + instance, + shard_ctx, + lk_multiplicity, + step.cycle(), + &ops.reg_ops[0], + )?; + + for (writer, op) in config.mem_rw.iter().zip(&ops.mem_ops) { + writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; + } + + lk_multiplicity.fetch(step.pc().before.0); + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 1a378ad8c..d1c0287d4 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -335,9 +335,20 @@ impl ReadMEM { step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.memory_op().unwrap(); + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) + } + + pub fn assign_op( + &self, + instance: &mut [E::BaseField], + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + cycle: Cycle, + op: &WriteOp, + ) -> Result<(), ZKVMError> { let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); - let shard_cycle = step.cycle() - current_shard_offset_cycle; + let shard_cycle = cycle - current_shard_offset_cycle; // Memory state set_val!(instance, self.prev_ts, shard_prev_cycle); @@ -353,7 +364,7 @@ impl ReadMEM { RAMType::Memory, op.addr, op.addr.baddr().0 as u64, - step.cycle() + Tracer::SUBCYCLE_MEM, + cycle + Tracer::SUBCYCLE_MEM, op.previous_cycle, op.value.after, None, diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 091ce3000..90d53fc21 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -21,9 +21,10 @@ use crate::{ div::{DivInstruction, DivuInstruction, RemInstruction, RemuInstruction}, ecall::{ Fp2AddInstruction, Fp2MulInstruction, FpAddInstruction, FpMulInstruction, - KeccakInstruction, Secp256k1InvInstruction, Secp256r1InvInstruction, - ShaExtendInstruction, Uint256MulInstruction, WeierstrassAddAssignInstruction, - WeierstrassDecompressInstruction, WeierstrassDoubleAssignInstruction, + KeccakInstruction, PubIoCommitInstruction, Secp256k1InvInstruction, + Secp256r1InvInstruction, ShaExtendInstruction, Uint256MulInstruction, + WeierstrassAddAssignInstruction, WeierstrassDecompressInstruction, + WeierstrassDoubleAssignInstruction, }, logic::{AndInstruction, OrInstruction, XorInstruction}, logic_imm::{AndiInstruction, OriInstruction, XoriInstruction}, @@ -45,10 +46,10 @@ use ceno_emul::{ Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, InsnKind::{self, *}, - KeccakSpec, LogPcCycleSpec, Platform, Secp256k1AddSpec, Secp256k1DecompressSpec, - Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, Secp256r1AddSpec, Secp256r1DoubleSpec, - Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepCellExtractor, StepIndex, StepRecord, - SyscallSpec, Uint256MulSpec, Word, + KeccakSpec, LogPcCycleSpec, Platform, PubIoCommitSpec, Secp256k1AddSpec, + Secp256k1DecompressSpec, Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, Secp256r1AddSpec, + Secp256r1DoubleSpec, Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepCellExtractor, StepIndex, + StepRecord, SyscallSpec, Uint256MulSpec, Word, }; use dummy::LargeEcallDummy; use ff_ext::ExtensionField; @@ -73,6 +74,7 @@ use strum::{EnumCount, IntoEnumIterator}; pub mod mmu; const ECALL_HALT: u32 = Platform::ecall_halt(); +const ECALL_PUB_IO_COMMIT: u32 = PubIoCommitSpec::CODE; pub struct Rv32imConfig { // ALU Opcodes. @@ -134,6 +136,7 @@ pub struct Rv32imConfig { // Ecall Opcodes pub halt_config: as Instruction>::InstructionConfig, + pub pubio_commit_config: as Instruction>::InstructionConfig, pub keccak_config: as Instruction>::InstructionConfig, pub sha_extend_config: as Instruction>::InstructionConfig, pub bn254_add_config: @@ -355,6 +358,8 @@ impl Rv32imConfig { }}; } let halt_config = register_ecall_circuit!(HaltInstruction, ecall_cells_map); + let pubio_commit_config = + register_ecall_circuit!(PubIoCommitInstruction, ecall_cells_map); // Keccak precompile is a known hotspot for peak memory. // Its heavy read/write/LK activity inflates tower-witness usage, causing @@ -468,6 +473,7 @@ impl Rv32imConfig { lb_config, // ecall opcodes halt_config, + pubio_commit_config, keccak_config, sha_extend_config, bn254_add_config, @@ -562,6 +568,7 @@ impl Rv32imConfig { // system fixed.register_opcode_circuit::>(cs, &self.halt_config); + fixed.register_opcode_circuit::>(cs, &self.pubio_commit_config); fixed.register_opcode_circuit::>(cs, &self.keccak_config); fixed.register_opcode_circuit::>(cs, &self.sha_extend_config); fixed.register_opcode_circuit::>>( @@ -650,6 +657,7 @@ impl Rv32imConfig { } log_ecall!("HALT", ECALL_HALT); + log_ecall!("PUB_IO_COMMIT", ECALL_PUB_IO_COMMIT); log_ecall!("KECCAK", KeccakSpec::CODE); log_ecall!("bn254_add_records", Bn254AddSpec::CODE); log_ecall!("bn254_double_records", Bn254DoubleSpec::CODE); @@ -761,6 +769,11 @@ impl Rv32imConfig { // ecall / halt assign_ecall!(HaltInstruction, halt_config, ECALL_HALT); + assign_ecall!( + PubIoCommitInstruction, + pubio_commit_config, + ECALL_PUB_IO_COMMIT + ); assign_ecall!(KeccakInstruction, keccak_config, KeccakSpec::CODE); assign_ecall!( WeierstrassAddAssignInstruction>, @@ -1042,6 +1055,10 @@ impl Rv32imConfig { .ecall_cells_map .get(&HaltInstruction::::name()) .expect("unable to find name"), + ECALL_PUB_IO_COMMIT => *self + .ecall_cells_map + .get(&PubIoCommitInstruction::::name()) + .expect("unable to find name"), KeccakSpec::CODE => *self .ecall_cells_map .get(&KeccakInstruction::::name()) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index b3f96feb0..5b264dad9 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -5,9 +5,9 @@ use crate::{ structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsInitCircuit, HintsTable, - LocalFinalCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOInitCircuit, - PubIOTable, RegTable, RegTableInitCircuit, ShardRamCircuit, StackInitCircuit, StackTable, - StaticMemInitCircuit, StaticMemTable, TableCircuit, + LocalFinalCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, RegTable, + RegTableInitCircuit, ShardRamCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, + StaticMemTable, TableCircuit, }, }; use ceno_emul::{Addr, IterAddresses, WORD_SIZE, Word}; @@ -20,8 +20,6 @@ pub struct MmuConfig { pub reg_init_config: as TableCircuit>::TableConfig, /// Initialization of memory with static addresses. pub static_mem_init_config: as TableCircuit>::TableConfig, - /// Initialization of public IO. - pub public_io_init_config: as TableCircuit>::TableConfig, /// Initialization of hints. pub hints_init_config: as TableCircuit>::TableConfig, /// Initialization of heap. @@ -41,8 +39,6 @@ impl MmuConfig { let static_mem_init_config = cs.register_table_circuit::>(); - let public_io_init_config = cs.register_table_circuit::>(); - let hints_init_config = cs.register_table_circuit::>(); let stack_init_config = cs.register_table_circuit::>(); let heap_init_config = cs.register_table_circuit::>(); @@ -52,7 +48,6 @@ impl MmuConfig { Self { reg_init_config, static_mem_init_config, - public_io_init_config, hints_init_config, stack_init_config, heap_init_config, @@ -68,12 +63,10 @@ impl MmuConfig { fixed: &mut ZKVMFixedTraces, reg_init: &[MemInitRecord], static_mem_init: &[MemInitRecord], - io_addrs: &[Addr], ) { assert!( chain!( static_mem_init.iter_addresses(), - io_addrs.iter_addresses(), // TODO: optimize with min_max and Range. self.params.platform.hints.iter_addresses(), ) @@ -88,12 +81,6 @@ impl MmuConfig { &self.static_mem_init_config, static_mem_init, ); - - fixed.register_table_circuit::>( - cs, - &self.public_io_init_config, - io_addrs, - ); fixed.register_table_circuit::>(cs, &self.hints_init_config, &()); fixed.register_table_circuit::>(cs, &self.stack_init_config, &()); fixed.register_table_circuit::>(cs, &self.heap_init_config, &()); @@ -130,7 +117,6 @@ impl MmuConfig { pv: &PublicValues, reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], - io_final: &[MemFinalRecord], stack_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { witness.assign_table_circuit::>( @@ -145,12 +131,6 @@ impl MmuConfig { static_mem_final, )?; - witness.assign_table_circuit::>( - cs, - &self.public_io_init_config, - io_final, - )?; - witness.assign_table_circuit::>( cs, &self.stack_init_config, @@ -168,13 +148,11 @@ impl MmuConfig { pv: &PublicValues, reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], - io_final: &[MemFinalRecord], hints_final: &[MemFinalRecord], stack_final: &[MemFinalRecord], heap_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { let all_records = vec![ - (PubIOTable::name(), None, io_final), (RegTable::name(), None, reg_final), (StaticMemTable::name(), None, static_mem_final), (StackTable::name(), None, stack_final), @@ -224,10 +202,6 @@ impl MmuConfig { pub fn static_mem_len(&self) -> usize { ::len(&self.params) } - - pub fn public_io_len(&self) -> usize { - ::len(&self.params) - } } pub struct MemPadder { diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index 78c8bfe95..76f6eb4a5 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -1003,7 +1003,6 @@ pub fn run_keccakf + 'stat &out_evals, &[], &[], - &[], &mut verifier_transcript, &selector_ctxs, ) diff --git a/ceno_zkvm/src/precompiles/fptower/fp.rs b/ceno_zkvm/src/precompiles/fptower/fp.rs index ba7c04308..a38b81e24 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp.rs @@ -478,7 +478,6 @@ mod tests { gkr_proof, &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs index 76d32b31a..e11a3fc13 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs @@ -523,7 +523,6 @@ mod tests { gkr_proof, &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs index c9160e6d9..352a01a54 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs @@ -536,7 +536,6 @@ mod tests { gkr_proof, &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 52391267a..5dda9f636 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -1263,7 +1263,6 @@ pub fn run_lookup_keccakf gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/mod.rs b/ceno_zkvm/src/precompiles/mod.rs index 3d9a6e545..f282efddb 100644 --- a/ceno_zkvm/src/precompiles/mod.rs +++ b/ceno_zkvm/src/precompiles/mod.rs @@ -1,6 +1,7 @@ mod bitwise_keccakf; mod fptower; mod lookup_keccakf; +mod pubio_commit; mod sha256; mod uint256; mod utils; @@ -12,6 +13,7 @@ pub use lookup_keccakf::{ ROUNDS as KECCAK_ROUNDS, ROUNDS_CEIL_LOG2 as KECCAK_ROUNDS_CEIL_LOG2, XOR_LOOKUPS, run_lookup_keccakf, setup_gkr_circuit as setup_lookup_keccak_gkr_circuit, }; +pub use pubio_commit::{PUBIO_COMMIT_WORDS, PUBIO_DIGEST_U16_LIMBS, PubioCommitLayout}; pub use bitwise_keccakf::{ KeccakLayout as BitwiseKeccakLayout, run_keccakf as run_bitwise_keccakf, diff --git a/ceno_zkvm/src/precompiles/pubio_commit.rs b/ceno_zkvm/src/precompiles/pubio_commit.rs new file mode 100644 index 000000000..35cf93eb4 --- /dev/null +++ b/ceno_zkvm/src/precompiles/pubio_commit.rs @@ -0,0 +1,37 @@ +use ff_ext::ExtensionField; +use multilinear_extensions::{Instance, ToExpr}; + +use crate::{ + chip_handler::{MemoryExpr, general::PublicValuesQuery}, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::riscv::constants::UINT_LIMBS, +}; + +pub const PUBIO_COMMIT_WORDS: usize = 8; +pub const PUBIO_DIGEST_U16_LIMBS: usize = PUBIO_COMMIT_WORDS * UINT_LIMBS; + +#[derive(Debug)] +pub struct PubioCommitLayout { + /// Public digest instances laid out as 16-bit limbs, little-endian per word. + pub digest_u16_limbs: [Instance; PUBIO_DIGEST_U16_LIMBS], + pub digest_words: [MemoryExpr; PUBIO_COMMIT_WORDS], +} + +impl PubioCommitLayout { + pub fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let digest_u16_limbs = cb.query_public_io_digest()?; + let digest_words = core::array::from_fn(|word_idx| { + let limb_base = word_idx * UINT_LIMBS; + [ + digest_u16_limbs[limb_base].expr(), + digest_u16_limbs[limb_base + 1].expr(), + ] + }); + + Ok(Self { + digest_u16_limbs, + digest_words, + }) + } +} diff --git a/ceno_zkvm/src/precompiles/sha256/extend.rs b/ceno_zkvm/src/precompiles/sha256/extend.rs index 235e37b95..3f87e2395 100644 --- a/ceno_zkvm/src/precompiles/sha256/extend.rs +++ b/ceno_zkvm/src/precompiles/sha256/extend.rs @@ -554,7 +554,6 @@ mod tests { gkr_proof, &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/uint256.rs b/ceno_zkvm/src/precompiles/uint256.rs index e59e97c55..d8e700faf 100644 --- a/ceno_zkvm/src/precompiles/uint256.rs +++ b/ceno_zkvm/src/precompiles/uint256.rs @@ -944,7 +944,6 @@ pub fn run_uint256_mul + ' gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 012dcab80..40cc0420d 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -778,7 +778,6 @@ pub fn run_weierstrass_add< gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index d6400a2d7..8b3cfa32e 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -758,7 +758,6 @@ pub fn run_weierstrass_decompress< gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 686baa397..f01b2727a 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -780,7 +780,6 @@ pub fn run_weierstrass_double< gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index fa9127f10..66e51912b 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -18,10 +18,16 @@ use crate::{ instructions::{ Instruction, riscv::{ - constants::{LIMB_BITS, LIMB_MASK, UINT_LIMBS}, + constants::{ + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, + HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, LIMB_BITS, + LIMB_MASK, PUBIO_DIGEST_IDX, PUBIO_DIGEST_U16_LIMBS, SHARD_ID_IDX, + SHARD_RW_SUM_IDX, UINT_LIMBS, + }, ecall::HaltInstruction, }, }, + scheme::constants::SEPTIC_EXTENSION_DEGREE, structs::{TowerProofs, ZKVMVerifyingKey}, }; @@ -65,13 +71,10 @@ pub struct ZKVMChipProof { pub ecc_proof: Option>, pub num_instances: Vec, - - pub fixed_in_evals: Vec, - pub wits_in_evals: Vec, } /// each field will be interpret to (constant) polynomial -#[derive(Default, Clone, Debug)] +#[derive(Default, Clone, Debug, Serialize, Deserialize)] pub struct PublicValues { pub exit_code: u32, pub init_pc: u32, @@ -83,11 +86,43 @@ pub struct PublicValues { pub heap_shard_len: u32, pub hint_start_addr: u32, pub hint_shard_len: u32, - pub public_io: Vec, - pub shard_rw_sum: Vec, + pub public_io_digest: [u32; 8], + pub shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], } impl PublicValues { + pub const fn flattened_len() -> usize { + PUBIO_DIGEST_IDX + PUBIO_DIGEST_U16_LIMBS + } + + pub fn iter_field<'a, Base: FieldAlgebra + 'a>(&'a self) -> impl Iterator + 'a { + [ + Base::from_canonical_u32(self.exit_code & 0xffff), + Base::from_canonical_u32((self.exit_code >> 16) & 0xffff), + Base::from_canonical_u32(self.init_pc), + Base::from_canonical_u64(self.init_cycle), + Base::from_canonical_u32(self.end_pc), + Base::from_canonical_u64(self.end_cycle), + Base::from_canonical_u32(self.shard_id), + Base::from_canonical_u32(self.heap_start_addr), + Base::from_canonical_u32(self.heap_shard_len), + Base::from_canonical_u32(self.hint_start_addr), + Base::from_canonical_u32(self.hint_shard_len), + ] + .into_iter() + .chain( + self.shard_rw_sum + .iter() + .map(|value| Base::from_canonical_u32(*value)), + ) + .chain(self.public_io_digest.iter().flat_map(|word| { + [ + Base::from_canonical_u32(word & 0xffff), + Base::from_canonical_u32((word >> 16) & 0xffff), + ] + })) + } + #[allow(clippy::too_many_arguments)] pub fn new( exit_code: u32, @@ -100,8 +135,8 @@ impl PublicValues { heap_shard_len: u32, hint_start_addr: u32, hint_shard_len: u32, - public_io: Vec, - shard_rw_sum: Vec, + public_io_digest: [u32; 8], + shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], ) -> Self { Self { exit_code, @@ -114,49 +149,42 @@ impl PublicValues { heap_shard_len, hint_start_addr, hint_shard_len, - public_io, + public_io_digest, shard_rw_sum, } } - pub fn to_vec(&self) -> Vec> { - vec![ - vec![E::BaseField::from_canonical_u32(self.exit_code & 0xffff)], - vec![E::BaseField::from_canonical_u32( - (self.exit_code >> 16) & 0xffff, - )], - vec![E::BaseField::from_canonical_u32(self.init_pc)], - vec![E::BaseField::from_canonical_u64(self.init_cycle)], - vec![E::BaseField::from_canonical_u32(self.end_pc)], - vec![E::BaseField::from_canonical_u64(self.end_cycle)], - vec![E::BaseField::from_canonical_u32(self.shard_id)], - vec![E::BaseField::from_canonical_u32(self.heap_start_addr)], - vec![E::BaseField::from_canonical_u32(self.heap_shard_len)], - vec![E::BaseField::from_canonical_u32(self.hint_start_addr)], - vec![E::BaseField::from_canonical_u32(self.hint_shard_len)], - ] - .into_iter() - .chain( - // public io processed into UINT_LIMBS column - (0..UINT_LIMBS) - .map(|limb_index| { - self.public_io - .iter() - .map(|value| { - E::BaseField::from_canonical_u16( - ((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16, - ) - }) - .collect_vec() - }) - .collect_vec(), - ) - .chain( - self.shard_rw_sum - .iter() - .map(|value| vec![E::BaseField::from_canonical_u32(*value)]) - .collect_vec(), - ) - .collect::>() + pub fn query_by_index(&self, index: usize) -> E::BaseField { + match index { + EXIT_CODE_IDX => E::BaseField::from_canonical_u32(self.exit_code & 0xffff), + idx if idx == EXIT_CODE_IDX + 1 => { + E::BaseField::from_canonical_u32((self.exit_code >> 16) & 0xffff) + } + INIT_PC_IDX => E::BaseField::from_canonical_u32(self.init_pc), + INIT_CYCLE_IDX => E::BaseField::from_canonical_u64(self.init_cycle), + END_PC_IDX => E::BaseField::from_canonical_u32(self.end_pc), + END_CYCLE_IDX => E::BaseField::from_canonical_u64(self.end_cycle), + SHARD_ID_IDX => E::BaseField::from_canonical_u32(self.shard_id), + HEAP_START_ADDR_IDX => E::BaseField::from_canonical_u32(self.heap_start_addr), + HEAP_LENGTH_IDX => E::BaseField::from_canonical_u32(self.heap_shard_len), + HINT_START_ADDR_IDX => E::BaseField::from_canonical_u32(self.hint_start_addr), + HINT_LENGTH_IDX => E::BaseField::from_canonical_u32(self.hint_shard_len), + idx if (SHARD_RW_SUM_IDX..(SHARD_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE * 2)) + .contains(&idx) => + { + E::BaseField::from_canonical_u32(self.shard_rw_sum[idx - SHARD_RW_SUM_IDX]) + } + idx if (PUBIO_DIGEST_IDX..(PUBIO_DIGEST_IDX + PUBIO_DIGEST_U16_LIMBS)) + .contains(&idx) => + { + let digest_limb_idx = idx - PUBIO_DIGEST_IDX; + let word_idx = digest_limb_idx / UINT_LIMBS; + let limb_idx = digest_limb_idx % UINT_LIMBS; + E::BaseField::from_canonical_u32( + (self.public_io_digest[word_idx] >> (limb_idx * LIMB_BITS)) & LIMB_MASK, + ) + } + _ => panic!("public value index {index} out of range"), + } } } @@ -169,11 +197,7 @@ impl PublicValues { deserialize = "E::BaseField: DeserializeOwned" ))] pub struct ZKVMProof> { - // TODO preserve in serde only for auxiliary public input - // other raw value can be construct by verifier directly. - pub raw_pi: Vec>, - // the evaluation of raw_pi. - pub pi_evals: Vec, + pub public_values: PublicValues, // each circuit may have multiple proof instances pub chip_proofs: BTreeMap>>, pub witin_commit: >::Commitment, @@ -182,41 +206,19 @@ pub struct ZKVMProof> { impl> ZKVMProof { pub fn new( - raw_pi: Vec>, - pi_evals: Vec, + public_values: PublicValues, chip_proofs: BTreeMap>>, witin_commit: >::Commitment, opening_proof: PCS::Proof, ) -> Self { Self { - raw_pi, - pi_evals, + public_values, chip_proofs, witin_commit, opening_proof, } } - pub fn pi_evals(raw_pi: &[Vec]) -> Vec { - raw_pi - .iter() - .map(|pv| { - if pv.len() == 1 { - // this is constant poly, and always evaluate to same constant value - E::from(pv[0]) - } else { - // set 0 as placeholder. will be evaluate lazily - // Or the vector is empty, i.e. the constant 0 polynomial. - E::ZERO - } - }) - .collect_vec() - } - - pub fn update_pi_eval(&mut self, idx: usize, v: E) { - self.pi_evals[idx] = v; - } - pub fn num_circuits(&self) -> usize { self.chip_proofs.len() } diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index a773f4de9..929c7850d 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -842,11 +842,6 @@ impl> MainSumcheckProver> MainSumcheckProver { pub witness: Vec>>, pub structural_witness: Vec>>, pub fixed: Vec>>, - pub public_input: Vec>>, - pub pub_io_evals: Vec::BaseField, PB::E>>, + pub pi: Vec::BaseField, PB::E>>, pub num_instances: Vec, pub has_ecc_ops: bool, } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index a477b1c6a..7842374fc 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -26,7 +26,7 @@ use gkr_iop::{ use itertools::{Itertools, chain, enumerate, izip}; use multilinear_extensions::{ Expression, WitnessId, fmt, - mle::{ArcMultilinearExtension, IntoMLEs, MultilinearExtension}, + mle::{ArcMultilinearExtension, MultilinearExtension}, util::ceil_log2, utils::{eval_by_expr, eval_by_expr_with_fixed, eval_by_expr_with_instance}, }; @@ -40,7 +40,6 @@ use std::{ hash::Hash, io::{BufReader, ErrorKind}, marker::PhantomData, - ops::Index, sync::OnceLock, }; use strum::IntoEnumIterator; @@ -624,7 +623,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { left, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -639,7 +638,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { &right, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -671,7 +670,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -717,7 +716,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -763,7 +762,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { arg_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -964,17 +963,16 @@ Hints: ) where E: LkMultiplicityKey, { - let pub_io_evals = pi - .to_vec::() - .into_iter() - .map(|v| Either::Right(E::from(*v.index(0)))) - .collect_vec(); - let pi_mles: Vec> = pi - .to_vec::() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(); + let get_circuit_pi_inputs = |circuit_cs: &ConstraintSystem| { + let circuit_pub_io_evals = circuit_cs + .instance + .iter() + .map(|instance| Either::Right(E::from(pi.query_by_index::(instance.0)))) + .collect_vec(); + let circuit_pi_mles = vec![]; + (circuit_pub_io_evals, circuit_pi_mles) + }; + let mut rng = thread_rng(); let challenges = [0u8; 2].map(|_| E::random(&mut rng)); @@ -1000,11 +998,7 @@ Hints: let ComposedConstrainSystem { zkvm_v1_css: cs, .. } = &composed_cs; - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); // skip init table on non-first shard if composed_cs.with_omc_init_only() && !shard_ctx.is_first_shard() { @@ -1061,8 +1055,8 @@ Hints: &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, num_rows, challenges, lkm_from_assignments, @@ -1089,12 +1083,12 @@ Hints: &expr.values, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ) .get_ext_field_vec() @@ -1104,12 +1098,12 @@ Hints: &expr.multiplicity, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ) .get_ext_field_vec() @@ -1162,11 +1156,7 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { @@ -1200,12 +1190,12 @@ Hints: ram_type_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let ram_type_vec = ram_type_mle.get_ext_field_vec(); @@ -1213,12 +1203,12 @@ Hints: w_rlc_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let w_selector_vec = w_selector.get_base_field_vec(); @@ -1268,11 +1258,7 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { continue; @@ -1304,12 +1290,12 @@ Hints: ram_type_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let ram_type_vec = ram_type_mle.get_ext_field_vec(); @@ -1317,12 +1303,12 @@ Hints: r_rlc_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let r_selector_vec = r_selector.get_base_field_vec(); @@ -1345,12 +1331,12 @@ Hints: expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); filter_mle_by_selector_mle(v, r_selector.clone()) @@ -1481,38 +1467,23 @@ Hints: let mut cb = CircuitBuilder::new(&mut cs); let gs_init = GlobalState::initial_global_state(&mut cb).unwrap(); let gs_final = GlobalState::finalize_global_state(&mut cb).unwrap(); + let gs_pub_io_evals = cs + .instance + .iter() + .map(|instance| E::from(pi.query_by_index::(instance.0))) + .collect_vec(); let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) = derive_ram_rws!(RAMType::GlobalState); gs_rs.insert( - eval_by_expr_with_instance( - &[], - &[], - &[], - &pub_io_evals - .iter() - .map(|v| v.right().unwrap()) - .collect_vec(), - &challenges, - &gs_final, - ) - .right() - .unwrap(), + eval_by_expr_with_instance(&[], &[], &[], &gs_pub_io_evals, &challenges, &gs_final) + .right() + .unwrap(), ); gs_ws.insert( - eval_by_expr_with_instance( - &[], - &[], - &[], - &pub_io_evals - .iter() - .map(|v| v.right().unwrap()) - .collect_vec(), - &challenges, - &gs_init, - ) - .right() - .unwrap(), + eval_by_expr_with_instance(&[], &[], &[], &gs_pub_io_evals, &challenges, &gs_init) + .right() + .unwrap(), ); // gs stores { (pc, timestamp) } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 99ad1ac39..f7f35add7 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -3,11 +3,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, hal::ProverBackend, }; -use std::{ - collections::{BTreeMap, HashMap}, - marker::PhantomData, - sync::Arc, -}; +use std::{collections::BTreeMap, marker::PhantomData, sync::Arc}; #[cfg(feature = "gpu")] use crate::scheme::gpu::estimate_chip_proof_memory; @@ -17,13 +13,9 @@ use crate::scheme::{ scheduler::{ChipScheduler, ChipTask, ChipTaskResult}, }; use either::Either; -use gkr_iop::hal::MultilinearPolynomial; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; -use multilinear_extensions::{ - Expression, Instance, - mle::{IntoMLE, MultilinearExtension}, -}; +use multilinear_extensions::{Expression, Instance}; use p3::field::FieldAlgebra; use std::iter::Iterator; use sumcheck::{ @@ -46,7 +38,7 @@ use crate::{ structs::{TowerProofs, ZKVMProvingKey, ZKVMWitnesses}, }; -type CreateTableProof = (ZKVMChipProof, HashMap, Point); +type CreateTableProof = (ZKVMChipProof, MainSumcheckEvals, Point); pub type ZkVMCpuProver = ZKVMProver, CpuProver>>; @@ -116,11 +108,15 @@ impl< } } - pub fn setup_init_mem(&self, hints: &[u32], public_io: &[u32]) -> crate::e2e::InitMemState { + pub fn setup_init_mem( + &self, + hints: &[u32], + public_io_digest_input: &[u32], + ) -> crate::e2e::InitMemState { let Some(ctx) = self.pk.program_ctx.as_ref() else { panic!("empty program ctx") }; - ctx.setup_init_mem(hints, public_io) + ctx.setup_init_mem(hints, public_io_digest_input) } } @@ -156,18 +152,17 @@ impl< shard_id = shard_ctx.shard_id ) .in_scope(|| { - let raw_pi = pi.to_vec::(); - let mut pi_evals = ZKVMProof::::pi_evals(&raw_pi); - let span = entered_span!("commit_to_pi", profiling_1 = true); - // including raw public input to transcript - for v in raw_pi.iter().flatten() { - transcript.append_field_element(v); + // Include transcript-visible public values in canonical circuit order. + // The order must match verifier and recursion verifier exactly. + // TODO deal with vector-based public value to transcript + for (_, circuit_pk) in self.pk.circuit_pks.iter() { + for instance_value in circuit_pk.get_cs().zkvm_v1_css.instance.iter() { + transcript.append_field_element(&pi.query_by_index::(instance_value.0)); + } } - exit_span!(span); - let pi: Vec> = - raw_pi.iter().map(|p| p.to_vec().into_mle()).collect(); + exit_span!(span); // commit to fixed commitment let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); @@ -275,10 +270,6 @@ impl< ]; tracing::debug!("global challenges in prover: {:?}", challenges); - let public_input_span = entered_span!("public_input", profiling_1 = true); - let public_input = self.device.transport_mles(&pi); - exit_span!(public_input_span); - let main_proofs_span = entered_span!("main_proofs", profiling_1 = true); // Phase 1: Build all ChipTasks @@ -291,8 +282,7 @@ impl< &witness_data, fixed_mles, challenges, - public_input, - &pi_evals, + &pi, &circuit_trace_indices, ); exit_span!(build_tasks_span); @@ -307,11 +297,7 @@ impl< // Phase 3: Collect results let collect_results_span = entered_span!("collect_chip_results", profiling_1 = true); - let (chip_proofs, points, evaluations, pi_updates) = - Self::collect_chip_results(results); - for (idx, eval) in pi_updates { - pi_evals[idx] = eval; - } + let (chip_proofs, points, evaluations) = Self::collect_chip_results(results); exit_span!(collect_results_span); exit_span!(main_proofs_span); @@ -335,13 +321,7 @@ impl< }); exit_span!(pcs_opening); - let vm_proof = ZKVMProof::new( - raw_pi, - pi_evals, - chip_proofs, - witin_commit, - mpcs_opening_proof, - ); + let vm_proof = ZKVMProof::new(pi, chip_proofs, witin_commit, mpcs_opening_proof); Ok(vm_proof) }) @@ -389,7 +369,7 @@ impl< let gpu_input: ProofInput<'static, gkr_iop::gpu::GpuBackend> = unsafe { std::mem::transmute(task.input) }; - let (proof, pi_in_evals, input_opening_point) = + let (proof, opening_evals, input_opening_point) = create_chip_proof_gpu_impl::( task.circuit_name.as_str(), task.pk, @@ -406,7 +386,7 @@ impl< task_id: task.task_id, circuit_idx: task.circuit_idx, proof, - pi_in_evals, + opening_evals, input_opening_point, has_witness_or_fixed: task.has_witness_or_fixed, }) @@ -424,14 +404,14 @@ impl< // Prepare: deferred extraction for GPU, no-op for CPU self.device.prepare_chip_input(&mut task, witness_data); - let (proof, pi_in_evals, input_opening_point) = + let (proof, opening_evals, input_opening_point) = self.create_chip_proof(&task, transcript)?; Ok(ChipTaskResult { task_id: task.task_id, circuit_idx: task.circuit_idx, proof, - pi_in_evals, + opening_evals, input_opening_point, has_witness_or_fixed: task.has_witness_or_fixed, }) @@ -523,20 +503,6 @@ impl< } = evals; exit_span!(span); - // evaluate pi if there is instance query - let mut pi_in_evals: HashMap = HashMap::new(); - if !cs.instance_openings().is_empty() { - let span = entered_span!("pi::evals", profiling_2 = true); - for &Instance(idx) in cs.instance_openings() { - let poly = &input.public_input[idx]; - pi_in_evals.insert( - idx, - poly.eval(input_opening_point[..poly.num_vars()].to_vec()), - ); - } - exit_span!(span); - } - Ok(( ZKVMChipProof { r_out_evals, @@ -546,11 +512,12 @@ impl< gkr_iop_proof, tower_proof, ecc_proof, - fixed_in_evals, - wits_in_evals, num_instances: input.num_instances.clone(), }, - pi_in_evals, + MainSumcheckEvals { + wits_in_evals, + fixed_in_evals, + }, input_opening_point, )) } @@ -567,8 +534,7 @@ impl< witness_data: &PB::PcsData, mut fixed_mles: Vec>>, challenges: [E; 2], - public_input: Vec>>, - pi_evals: &[E], + pi: &PublicValues, circuit_trace_indices: &[Option], ) -> Vec> { // CPU path: eagerly extract witness MLEs from pcs_data @@ -645,18 +611,25 @@ impl< }; let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); + + let circuit_pi = cs + .zkvm_v1_css + .instance + .iter() + .map(|Instance(idx)| Either::Left(pi.query_by_index::(*idx))) + .collect_vec(); + let input_temp: ProofInput<'_, PB> = ProofInput { witness: witness_mle, fixed, structural_witness, - public_input: public_input.clone(), - pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(), + pi: circuit_pi, num_instances: num_instances.clone(), has_ecc_ops: cs.has_ecc_ops(), }; // SAFETY: All Arcs in ProofInput contain 'static data: // - GPU path: `witness` and `structural_witness` are empty vecs (deferred extraction), - // `fixed` and `public_input` originate from `DeviceProvingKey<'static, PB>`. + // `fixed` originates from `DeviceProvingKey<'static, PB>`. // - CPU path: `witness_mle` may borrow non-'static data, but the CPU path always // uses sequential execution (never enters the concurrent scheduler), so the data // remains valid for the lifetime of `build_chip_tasks`'s caller. @@ -718,12 +691,10 @@ impl< BTreeMap>>, Vec>, Vec>>, - HashMap, ) { let mut chip_proofs = BTreeMap::new(); let mut points = Vec::new(); let mut evaluations = Vec::new(); - let mut pi_updates = HashMap::new(); for result in results { tracing::trace!( @@ -735,23 +706,17 @@ impl< if result.has_witness_or_fixed { points.push(result.input_opening_point); evaluations.push(vec![ - result.proof.wits_in_evals.clone(), - result.proof.fixed_in_evals.clone(), + result.opening_evals.wits_in_evals, + result.opening_evals.fixed_in_evals, ]); - } else { - assert!(result.proof.wits_in_evals.is_empty()); - assert!(result.proof.fixed_in_evals.is_empty()); } chip_proofs .entry(result.circuit_idx) .or_insert(vec![]) .push(result.proof); - for (idx, eval) in result.pi_in_evals { - pi_updates.insert(idx, eval); - } } - (chip_proofs, points, evaluations, pi_updates) + (chip_proofs, points, evaluations) } } @@ -883,20 +848,6 @@ where } = evals; exit_span!(span); - // evaluate pi if there is instance query - let mut pi_in_evals: HashMap = HashMap::new(); - if !cs.instance_openings().is_empty() { - let span = entered_span!("pi::evals", profiling_2 = true); - for &Instance(idx) in cs.instance_openings() { - let poly = &input.public_input[idx]; - pi_in_evals.insert( - idx, - poly.eval(input_opening_point[..poly.num_vars()].to_vec()), - ); - } - exit_span!(span); - } - Ok(( ZKVMChipProof { r_out_evals, @@ -906,11 +857,12 @@ where gkr_iop_proof, tower_proof, ecc_proof, - fixed_in_evals, - wits_in_evals, num_instances: input.num_instances, }, - pi_in_evals, + MainSumcheckEvals { + wits_in_evals, + fixed_in_evals, + }, input_opening_point, )) } diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index 438421b2e..c1bd260f4 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -12,14 +12,17 @@ use crate::{ error::ZKVMError, - scheme::{ZKVMChipProof, hal::ProofInput}, + scheme::{ + ZKVMChipProof, + hal::{MainSumcheckEvals, ProofInput}, + }, structs::ProvingKey, }; use ff_ext::ExtensionField; use gkr_iop::hal::ProverBackend; use mpcs::Point; use p3::field::FieldAlgebra; -use std::{collections::HashMap, sync::OnceLock}; +use std::sync::OnceLock; use transcript::Transcript; static CHIP_PROVING_MODE: OnceLock = OnceLock::new(); @@ -77,8 +80,8 @@ pub struct ChipTaskResult { pub circuit_idx: usize, /// The generated proof pub proof: ZKVMChipProof, - /// Public input evaluations - pub pi_in_evals: HashMap, + /// Prover-only opening evaluations split by witness/fixed/pi domains. + pub opening_evals: MainSumcheckEvals, /// Opening point for this proof pub input_opening_point: Point, /// Whether this circuit has witness or fixed polynomials diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 91a4e4563..2f5fe8ca2 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -34,7 +34,7 @@ use ff_ext::{Instrumented, PoseidonField}; use super::{ PublicValues, - constants::MAX_NUM_VARIABLES, + constants::{MAX_NUM_VARIABLES, SEPTIC_EXTENSION_DEGREE}, prover::ZKVMProver, utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, @@ -52,6 +52,39 @@ use p3::field::FieldAlgebra; use rand::thread_rng; use transcript::{BasicTranscript, Transcript}; +#[test] +fn test_public_values_iter_field_matches_query_order() { + type E = GoldilocksExt2; + + let public_values = PublicValues::new( + 0xABCD_1234, + 0x0800_0000, + 123, + 0x0800_1000, + 456, + 7, + 0x3000_0000, + 64, + 0x2800_0000, + 32, + [0, 1, 2, 3, 4, 5, 6, 7], + std::array::from_fn(|i| (i as u32) + 10), + ); + + let from_iter = public_values + .iter_field::<::BaseField>() + .collect_vec(); + let from_query = (0..PublicValues::flattened_len()) + .map(|i| public_values.query_by_index::(i)) + .collect_vec(); + + assert_eq!( + PublicValues::flattened_len(), + 11 + SEPTIC_EXTENSION_DEGREE * 2 + 16 + ); + assert_eq!(from_iter, from_query); +} + struct TestConfig { pub(crate) reg_id: WitIn, } @@ -211,8 +244,7 @@ fn test_rw_lk_expression_combination() { fixed: vec![], witness: wits_in, structural_witness: structural_in, - public_input: vec![], - pub_io_evals: vec![], + pi: vec![], num_instances: vec![num_instances], has_ecc_ops: false, }; @@ -248,13 +280,12 @@ fn test_rw_lk_expression_combination() { { Instrumented::<<::BaseField as PoseidonField>::P>::clear_metrics(); } - verifier + let _ = verifier .verify_chip_proof( name.as_str(), verifier.vk.circuit_vks.get(&name).unwrap(), &proof, - &[], - &[], + &PublicValues::default(), &mut v_transcript, NUM_FANIN, &PointAndEval::default(), @@ -397,7 +428,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], vec![0; 14]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, [0; 8], [0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(&shard_ctx, zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index aa2be21fe..8583e2710 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -359,12 +359,6 @@ pub fn build_main_witness< "assert circuit" ); - let pub_io_mles = cs - .instance_openings - .iter() - .map(|instance| input.public_input[instance.0].clone()) - .collect_vec(); - // check all witness size are power of 2 assert!( input @@ -372,7 +366,6 @@ pub fn build_main_witness< .iter() .chain(&input.structural_witness) .chain(&input.fixed) - .chain(&pub_io_mles) .all(|v| { v.evaluations_len() == 1 << num_var_with_rotation }) ); @@ -387,8 +380,8 @@ pub fn build_main_witness< &input.witness, &input.structural_witness, &input.fixed, - &pub_io_mles, - &input.pub_io_evals, + &[], + &input.pi, challenges, ); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 2c49ee658..71b56d290 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,5 +1,5 @@ use either::Either; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, SmallField}; use std::{ iter::{self, once, repeat_n}, marker::PhantomData, @@ -8,11 +8,12 @@ use std::{ #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use super::{ZKVMChipProof, ZKVMProof}; +use super::{PublicValues, ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, instructions::riscv::constants::{ - END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, SHARD_ID_IDX, + END_CYCLE_IDX, END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, + INIT_PC_IDX, }, scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, @@ -51,6 +52,39 @@ pub struct ZKVMVerifier> { } impl> ZKVMVerifier { + #[allow(clippy::type_complexity)] + fn split_input_opening_evals( + circuit_vk: &VerifyingKey, + proof: &ZKVMChipProof, + ) -> Result<(Vec, Vec), ZKVMError> { + let cs = circuit_vk.get_cs(); + let Some(gkr_proof) = proof.gkr_iop_proof.as_ref() else { + return Err(ZKVMError::InvalidProof("missing gkr proof".into())); + }; + let Some(last_layer) = gkr_proof.0.last() else { + return Err(ZKVMError::InvalidProof("empty gkr proof layers".into())); + }; + + let evals = &last_layer.main.evals; + let wit_len = cs.num_witin(); + let fixed_len = cs.num_fixed(); + let min_len = wit_len + fixed_len; + if evals.len() < min_len { + return Err(ZKVMError::InvalidProof( + format!( + "insufficient main evals: {} < required {}", + evals.len(), + min_len + ) + .into(), + )); + } + + let wits_in_evals = evals[..wit_len].to_vec(); + let fixed_in_evals = evals[wit_len..(wit_len + fixed_len)].to_vec(); + Ok((wits_in_evals, fixed_in_evals)) + } + pub fn new(vk: ZKVMVerifyingKey) -> Self { ZKVMVerifier { vk } } @@ -116,19 +150,34 @@ impl> ZKVMVerifier } // each shard set init cycle = Tracer::SUBCYCLES_PER_INSN // to satisfy initial reads for all prev_cycle = 0 < init_cycle - assert_eq!(vm_proof.pi_evals[INIT_CYCLE_IDX], E::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN)); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_CYCLE_IDX), + E::BaseField::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN) + ); // check init_pc match prev end_pc if let Some(prev_pc) = prev_pc { - assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], prev_pc); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_PC_IDX), + prev_pc + ); } else { // first chunk, check program entry - assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], E::from_canonical_u32(self.vk.entry_pc)); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_PC_IDX), + E::BaseField::from_canonical_u32(self.vk.entry_pc) + ); } - let end_pc = vm_proof.pi_evals[END_PC_IDX]; + let end_pc = vm_proof.public_values.query_by_index::(END_PC_IDX); // check memory continuation consistency - let heap_addr_start_u32 = vm_proof.pi_evals[HEAP_START_ADDR_IDX].to_canonical_u64() as u32; - let heap_len= vm_proof.pi_evals[HEAP_LENGTH_IDX].to_canonical_u64() as u32; + let heap_addr_start_u32 = vm_proof + .public_values + .query_by_index::(HEAP_START_ADDR_IDX) + .to_canonical_u64() as u32; + let heap_len = vm_proof + .public_values + .query_by_index::(HEAP_LENGTH_IDX) + .to_canonical_u64() as u32; if let Some(prev_heap_addr_end) = prev_heap_addr_end { assert_eq!(heap_addr_start_u32, prev_heap_addr_end); // TODO check heap addr in prime field within range @@ -165,7 +214,12 @@ impl> ZKVMVerifier let mut prod_w = E::ONE; let mut logup_sum = E::ZERO; - let pi_evals = &vm_proof.pi_evals; + // Global-state expressions are built from compact instance IDs + // (query order), not absolute public-value indices. + let pi = [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] + .into_iter() + .map(|idx| E::from(vm_proof.public_values.query_by_index::(idx))) + .collect_vec(); // make sure circuit index of chip proofs are // subset of that of self.vk.circuit_vks @@ -181,33 +235,18 @@ impl> ZKVMVerifier } } - // TODO fix soundness: construct raw public input by ourself and trustless from proof - // including raw public input to transcript - vm_proof - .raw_pi - .iter() - .for_each(|v| v.iter().for_each(|v| transcript.append_field_element(v))); + // Include transcript-visible public values in canonical circuit order. + // This must match prover and recursion verifier exactly. + for (_, circuit_vk) in self.vk.circuit_vks.iter() { + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance.iter() { + transcript.append_field_element( + &vm_proof.public_values.query_by_index::(instance_value.0), + ); + } + } // check shard id - assert_eq!( - vm_proof.raw_pi[SHARD_ID_IDX], - vec![E::BaseField::from_canonical_usize(shard_id)] - ); - - // verify constant poly(s) evaluation result match - // we can evaluate at this moment because constant always evaluate to same value - // non-constant poly(s) will be verified in respective (table) proof accordingly - izip!(&vm_proof.raw_pi, pi_evals) - .enumerate() - .try_for_each(|(i, (raw, eval))| { - if raw.len() == 1 && E::from(raw[0]) != *eval { - Err(ZKVMError::VerifyError( - format!("{shard_id}th shard pub input on index {i} mismatch {raw:?} != {eval:?}").into(), - )) - } else { - Ok(()) - } - })?; + assert_eq!(vm_proof.public_values.shard_id, shard_id as u32); // write fixed commitment to transcript // TODO check soundness if there is no fixed_commit but got fixed proof? @@ -291,21 +330,6 @@ impl> ZKVMVerifier let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; - // check chip proof is well-formed - if proof.wits_in_evals.len() != circuit_vk.get_cs().num_witin() - || proof.fixed_in_evals.len() != circuit_vk.get_cs().num_fixed() - { - return Err(ZKVMError::InvalidProof( - format!( - "{shard_id}th shard witness/fixed evaluations length mismatch: ({}, {}) != ({}, {})", - proof.wits_in_evals.len(), - proof.fixed_in_evals.len(), - circuit_vk.get_cs().num_witin(), - circuit_vk.get_cs().num_fixed(), - ) - .into(), - )); - } if proof.r_out_evals.len() != circuit_vk.get_cs().num_reads() || proof.w_out_evals.len() != circuit_vk.get_cs().num_writes() { @@ -341,44 +365,44 @@ impl> ZKVMVerifier .sum::(); transcript.append_field_element(&E::BaseField::from_canonical_u64(*index as u64)); - if circuit_vk.get_cs().is_with_lk_table() { - logup_sum -= chip_logup_sum; - } else { - // getting the number of dummy padding item that we used in this opcode circuit - let num_lks = circuit_vk.get_cs().num_lks(); - // each padding instance contribute to (2^rotation_vars) dummy lookup padding - let num_padded_instance = (next_pow2_instance_padding(num_instance) - num_instance) - * (1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)); - // each instance contribute to (2^rotation_vars - rotated) dummy lookup padding - let num_instance_non_selected = num_instance - * ((1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)) - - (circuit_vk.get_cs().rotation_subgroup_size().unwrap_or(0) + 1)); - dummy_table_item_multiplicity += - num_lks * (num_padded_instance + num_instance_non_selected); - - logup_sum += chip_logup_sum; - }; - let (input_opening_point, chip_shard_ec_sum) = self.verify_chip_proof( - circuit_name, - circuit_vk, - proof, - pi_evals, - &vm_proof.raw_pi, - transcript, - NUM_FANIN, - &point_eval, - &challenges, - )?; + + // compute logup_sum padding + // getting the number of dummy padding item that we used in this opcode circuit + let num_lks = circuit_vk.get_cs().num_lks(); + // each padding instance contribute to (2^rotation_vars) dummy lookup padding + let num_padded_instance = (next_pow2_instance_padding(num_instance) - num_instance) + * (1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)); + // each instance contribute to (2^rotation_vars - rotated) dummy lookup padding + let num_instance_non_selected = num_instance + * ((1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)) + - (circuit_vk.get_cs().rotation_subgroup_size().unwrap_or(0) + 1)); + dummy_table_item_multiplicity += + num_lks * (num_padded_instance + num_instance_non_selected); + + // accumulate logup_sum + logup_sum += chip_logup_sum; + + let (input_opening_point, chip_shard_ec_sum, wits_in_evals, fixed_in_evals) = self + .verify_chip_proof( + circuit_name, + circuit_vk, + proof, + &vm_proof.public_values, + transcript, + NUM_FANIN, + &point_eval, + &challenges, + )?; if circuit_vk.get_cs().num_witin() > 0 { witin_openings.push(( input_opening_point.len(), - (input_opening_point.clone(), proof.wits_in_evals.clone()), + (input_opening_point.clone(), wits_in_evals), )); } if circuit_vk.get_cs().num_fixed() > 0 { fixed_openings.push(( input_opening_point.len(), - (input_opening_point.clone(), proof.fixed_in_evals.clone()), + (input_opening_point.clone(), fixed_in_evals), )); } prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); @@ -435,7 +459,7 @@ impl> ZKVMVerifier &[], &[], &[], - pi_evals, + &pi, &challenges, &self.vk.initial_global_state_expr, ) @@ -446,7 +470,7 @@ impl> ZKVMVerifier &[], &[], &[], - pi_evals, + &pi, &challenges, &self.vk.finalize_global_state_expr, ) @@ -478,13 +502,12 @@ impl> ZKVMVerifier _name: &str, circuit_vk: &VerifyingKey, proof: &ZKVMChipProof, - pi: &[E], - raw_pi: &[Vec], + public_values: &PublicValues, transcript: &mut impl Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS - ) -> Result<(Point, Option>), ZKVMError> { + ) -> Result<(Point, Option>, Vec, Vec), ZKVMError> { let composed_cs = circuit_vk.get_cs(); let ComposedConstrainSystem { zkvm_v1_css: cs, @@ -669,17 +692,22 @@ impl> ZKVMVerifier }, ] }; + let pi = cs + .instance + .iter() + .map(|instance| E::from(public_values.query_by_index::(instance.0))) + .collect_vec(); + let (wits_in_evals, fixed_in_evals) = Self::split_input_opening_evals(circuit_vk, proof)?; let (_, rt) = gkr_circuit.verify( num_var_with_rotation, proof.gkr_iop_proof.clone().unwrap(), &evals, - pi, - raw_pi, + &pi, challenges, transcript, &selector_ctxs, )?; - Ok((rt, shard_ec_sum)) + Ok((rt, shard_ec_sum, wits_in_evals, fixed_in_evals)) } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f6847140..3490c2f2c 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -155,8 +155,8 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.w_expressions.len() + self.zkvm_v1_css.w_table_expressions.len() } - pub fn instance_openings(&self) -> &[Instance] { - &self.zkvm_v1_css.instance_openings + pub fn instance(&self) -> &[Instance] { + &self.zkvm_v1_css.instance } pub fn has_ecc_ops(&self) -> bool { !self.zkvm_v1_css.ec_final_sum.is_empty() diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index d934127f9..505178944 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -2,7 +2,7 @@ use ceno_emul::{Addr, VM_REG_COUNT, WORD_SIZE}; use ff_ext::ExtensionField; use gkr_iop::error::CircuitBuilderError; use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; -use ram_circuit::{DynVolatileRamCircuit, NonVolatileRamCircuit, PubIORamInitCircuit}; +use ram_circuit::{DynVolatileRamCircuit, NonVolatileRamCircuit}; use crate::{ instructions::riscv::constants::UINT_LIMBS, @@ -38,11 +38,12 @@ impl DynVolatileRamTable for HeapTable { params: &ProgramParams, ) -> Result<(Expression, StructuralWitIn), CircuitBuilderError> { let max_len = Self::max_len(params); + let offset_instance_id = cb.query_heap_start_addr()?.0 as WitnessId; let addr = cb.create_structural_witin( || "addr", StructuralWitInType::EqualDistanceDynamicSequence { max_len, - offset_instance_id: cb.query_heap_start_addr()?.0 as WitnessId, + offset_instance_id, multi_factor: WORD_SIZE, descending: Self::DESCENDING, }, @@ -143,11 +144,12 @@ impl DynVolatileRamTable for HintsTable { params: &ProgramParams, ) -> Result<(Expression, StructuralWitIn), CircuitBuilderError> { let max_len = Self::max_len(params); + let offset_instance_id = cb.query_hint_start_addr()?.0 as WitnessId; let addr = cb.create_structural_witin( || "addr", StructuralWitInType::EqualDistanceDynamicSequence { max_len, - offset_instance_id: cb.query_hint_start_addr()?.0 as WitnessId, + offset_instance_id, multi_factor: WORD_SIZE, descending: Self::DESCENDING, }, @@ -239,22 +241,4 @@ impl NonVolatileTable for StaticMemTable { pub type StaticMemInitCircuit = NonVolatileRamCircuit>; -#[derive(Clone)] -pub struct PubIOTable; - -impl NonVolatileTable for PubIOTable { - const RAM_TYPE: RAMType = RAMType::Memory; - const V_LIMBS: usize = UINT_LIMBS; - const WRITABLE: bool = false; - - fn name() -> &'static str { - "PubIOTable" - } - - fn len(params: &ProgramParams) -> usize { - params.pubio_len - } -} - -pub type PubIOInitCircuit = PubIORamInitCircuit; pub type LocalFinalCircuit = LocalFinalRamCircuit; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 249f70125..0afa819be 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,6 +1,4 @@ -use super::ram_impl::{ - LocalFinalRAMTableConfig, NonVolatileTableConfigTrait, PubIOTableInitConfig, -}; +use super::ram_impl::{LocalFinalRAMTableConfig, NonVolatileTableConfigTrait}; use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, @@ -123,58 +121,6 @@ impl< } } -/// PubIORamCircuit initializes and finalizes memory -/// - at fixed addresses, -/// - with content from the public input of proofs. -/// -/// This circuit does not and cannot decide whether the memory is mutable or not. -/// It supports LOAD where the program reads the public input, -/// or STORE where the memory content must equal the public input after execution. -pub struct PubIORamInitCircuit(PhantomData<(E, R)>); - -impl TableCircuit - for PubIORamInitCircuit -{ - type TableConfig = PubIOTableInitConfig; - type FixedInput = [Addr]; - type WitnessInput<'a> = [MemFinalRecord]; - - fn name() -> String { - format!("RAM_{:?}_{}", NVRAM::RAM_TYPE, NVRAM::name()) - } - - fn construct_circuit( - cb: &mut CircuitBuilder, - params: &ProgramParams, - ) -> Result { - cb.set_omc_init_only(); - Ok(cb.namespace( - || Self::name(), - |cb| Self::TableConfig::construct_circuit(cb, params), - )?) - } - - fn generate_fixed_traces( - config: &Self::TableConfig, - num_fixed: usize, - io_addrs: &[Addr], - ) -> RowMajorMatrix { - // assume returned table is well-formed including padding - config.gen_init_state(num_fixed, io_addrs) - } - - fn assign_instances( - config: &Self::TableConfig, - num_witin: usize, - num_structural_witin: usize, - _multiplicity: &[HashMap], - final_mem: &[MemFinalRecord], - ) -> Result, ZKVMError> { - // assume returned table is well-formed including padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_mem)?) - } -} - /// - **Dynamic**: The address space is bounded within a specific range, /// though the range itself may be dynamically determined per proof. /// - **Volatile**: The initial values are set to `0` diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 7e5f1293d..78049ecb3 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -14,7 +14,6 @@ use super::{ ram_circuit::{DynVolatileRamTable, MemFinalRecord, NonVolatileTable}, }; use crate::{ - chip_handler::general::PublicValuesQuery, circuit_builder::{CircuitBuilder, SetTableSpec}, e2e::ShardContext, instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, @@ -161,96 +160,6 @@ impl NonVolatileTableConfigTrait< } } -/// define public io -/// init value set by instance -#[derive(Clone, Debug)] -pub struct PubIOTableInitConfig { - addr: Fixed, - phantom: PhantomData, - params: ProgramParams, -} - -impl PubIOTableInitConfig { - pub fn construct_circuit( - cb: &mut CircuitBuilder, - params: &ProgramParams, - ) -> Result { - assert!(!NVRAM::WRITABLE); - let init_v = cb.query_public_io()?; - let addr = cb.create_fixed(|| "addr"); - - let init_table = [ - vec![(NVRAM::RAM_TYPE as usize).into()], - vec![Expression::Fixed(addr)], - init_v.iter().map(|v| v.expr_as_instance()).collect_vec(), - vec![Expression::ZERO], // Initial cycle. - ] - .concat(); - - cb.w_table_record( - || "init_table", - NVRAM::RAM_TYPE, - SetTableSpec { - len: Some(NVRAM::len(params)), - structural_witins: vec![], - }, - init_table, - )?; - - Ok(Self { - addr, - phantom: PhantomData, - params: params.clone(), - }) - } - - /// assign to fixed address - pub fn gen_init_state( - &self, - num_fixed: usize, - io_addrs: &[Addr], - ) -> RowMajorMatrix { - assert!(NVRAM::len(&self.params).is_power_of_two()); - - let mut init_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), - num_fixed, - InstancePaddingStrategy::Default, - ); - assert_eq!(init_table.num_padding_instances(), 0); - - init_table - .par_rows_mut() - .zip_eq(io_addrs) - .for_each(|(row, addr)| { - set_fixed_val!(row, self.addr, (*addr as u64).into_f()); - }); - init_table - } - - /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( - &self, - _num_witin: usize, - num_structural_witin: usize, - final_mem: &[MemFinalRecord], - ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - if final_mem.is_empty() { - return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); - } - assert!(num_structural_witin == 0 || num_structural_witin == 1); - let mut value = Vec::with_capacity(NVRAM::len(&self.params)); - value.par_extend( - (0..NVRAM::len(&self.params)) - .into_par_iter() - .map(|_| F::ONE), - ); - let structural_witness = - RowMajorMatrix::::new_by_values(value, 1, InstancePaddingStrategy::Default); - Ok([RowMajorMatrix::empty(), structural_witness]) - } -} - /// volatile with all init value as 0 /// dynamic address as witin, relied on augment of knowledge to prove address form #[derive(Clone, Debug)] diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..fb59fdbb7 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -651,7 +651,7 @@ mod tests { use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; use p3::babybear::BabyBear; use rand::thread_rng; - use std::{ops::Index, sync::Arc}; + use std::sync::Arc; use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use transcript::BasicTranscript; @@ -659,8 +659,8 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, scheme::{ - PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, - septic_curve::SepticPoint, verifier::ZKVMVerifier, + PublicValues, constants::SEPTIC_EXTENSION_DEGREE, create_backend, create_prover, + hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, verifier::ZKVMVerifier, }, structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, tables::{ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit}, @@ -670,7 +670,6 @@ mod tests { gpu::{MultilinearExtensionGpu, get_cuda_hal}, hal::MultilinearPolynomial, }; - use multilinear_extensions::mle::IntoMLE; use p3::field::PrimeField32; type E = BabyBearExt4; @@ -754,25 +753,17 @@ mod tests { .map(|record| record.ec_point.point.clone()) .sum(); - let public_value = PublicValues::new( - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - vec![0], // dummy - global_ec_sum - .x - .iter() - .chain(global_ec_sum.y.iter()) - .map(|fe| fe.as_canonical_u32()) - .collect_vec(), - ); + let mut shard_rw_sum = [0u32; SEPTIC_EXTENSION_DEGREE * 2]; + for (i, fe) in global_ec_sum + .x + .iter() + .chain(global_ec_sum.y.iter()) + .enumerate() + { + shard_rw_sum[i] = fe.as_canonical_u32(); + } + + let public_value = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, [0; 8], shard_rw_sum); // assign witness let witness = ShardRamCircuit::assign_instances( @@ -807,35 +798,26 @@ mod tests { let zkvm_prover = ZKVMProver::new(zkvm_pk.into(), pd); let mut transcript = BasicTranscript::new(b"global chip test"); - let pub_io_evals = public_value - .to_vec::() - .into_iter() - .map(|v| Either::Right(E::from(*v.index(0)))) + let pub_io_evals = pk + .get_cs() + .zkvm_v1_css + .instance + .iter() + .map(|instance| Either::Right(E::from(public_value.query_by_index::(instance.0)))) .collect_vec(); #[cfg(not(feature = "gpu"))] - let (witness_mles, structural_mles, public_input_mles) = { - let public_input_mles = public_value - .to_vec::() - .into_iter() - .map(|v| Arc::new(v.into_mle())) - .collect_vec(); + let (witness_mles, structural_mles) = { ( witness[0].to_mles().into_iter().map(Arc::new).collect(), witness[1].to_mles().into_iter().map(Arc::new).collect(), - public_input_mles, ) }; #[cfg(feature = "gpu")] - let (witness_mles, structural_mles, public_input_mles) = { + let (witness_mles, structural_mles) = { let cuda_hal = get_cuda_hal().unwrap(); let witness_cpu: Vec<_> = witness[0].to_mles(); let structural_cpu: Vec<_> = witness[1].to_mles(); - let public_cpu: Vec<_> = public_value - .to_vec::() - .into_iter() - .map(|v| v.into_mle()) - .collect_vec(); ( witness_cpu .iter() @@ -845,10 +827,6 @@ mod tests { .iter() .map(|v| Arc::new(MultilinearExtensionGpu::from_ceno(&cuda_hal, v))) .collect_vec(), - public_cpu - .iter() - .map(|v| Arc::new(MultilinearExtensionGpu::from_ceno(&cuda_hal, v))) - .collect_vec(), ) }; @@ -856,8 +834,7 @@ mod tests { witness: witness_mles, structural_witness: structural_mles, fixed: vec![], - public_input: public_input_mles.clone(), - pub_io_evals, + pi: pub_io_evals, num_instances: vec![n_global_writes as usize, n_global_reads as usize], has_ecc_ops: true, }; @@ -882,17 +859,12 @@ mod tests { let mut transcript = BasicTranscript::new(b"global chip test"); let verifier = ZKVMVerifier::new(zkvm_vk); - let pi_evals = public_input_mles - .iter() - .map(|mle| mle.evaluate(&point[..mle.num_vars()])) - .collect_vec(); - let (vrf_point, _) = verifier + let (vrf_point, _, _, _) = verifier .verify_chip_proof( "global", &pk.vk, &proof, - &pi_evals, - &public_value.to_vec::(), + &public_value, &mut transcript, 2, &PointAndEval::default(), diff --git a/examples/examples/fibonacci.rs b/examples/examples/fibonacci.rs index ae44e6707..2f413677f 100644 --- a/examples/examples/fibonacci.rs +++ b/examples/examples/fibonacci.rs @@ -12,6 +12,6 @@ fn main() { a = b; b = c; } - // Constrain with public io - ceno_rt::commit(&b); + // Constrain with public io digest. + ceno_rt::commit(&b.to_le_bytes()); } diff --git a/examples/examples/sha256.rs b/examples/examples/sha256.rs index f485d9d29..05dfb0a27 100644 --- a/examples/examples/sha256.rs +++ b/examples/examples/sha256.rs @@ -9,13 +9,7 @@ fn main() { let input: Vec = ceno_rt::read(); let h = Sha256::digest(&input); - let h: [u8; 32] = h.into(); - let h: [u32; 8] = core::array::from_fn(|i| { - let chunk = &h[4 * i..][..4]; - u32::from_be_bytes(chunk.try_into().unwrap()) - }); - + let h_bytes: [u8; 32] = h.into(); // Output the final hash values one by one - ceno_rt::commit(&h); - // debug_print!("{:x}", h[0]); + ceno_rt::commit(&h_bytes); } diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 2b72f251e..10048418e 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -44,8 +44,7 @@ impl Chip { + cb.cs.r_table_expressions.len() + cb.cs.lk_table_expressions.len() * 2 + cb.cs.num_fixed - + cb.cs.num_witin as usize - + cb.cs.instance_openings.len(), + + cb.cs.num_witin as usize, final_out_evals: (0..cb.cs.w_expressions.len() + cb.cs.r_expressions.len() + cb.cs.lk_expressions.len() diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index d84cb4d55..b66d05300 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -1,3 +1,4 @@ +use ff_ext::ExtensionField; use itertools::{Itertools, chain}; use multilinear_extensions::{ Expression, Fixed, Instance, StructuralWitIn, StructuralWitInType, ToExpr, WitIn, WitnessId, @@ -6,8 +7,6 @@ use multilinear_extensions::{ use serde::de::DeserializeOwned; use std::{collections::HashMap, iter::once, marker::PhantomData}; -use ff_ext::ExtensionField; - use crate::{ RAMType, error::CircuitBuilderError, gkr::layer::ROTATION_OPENING_COUNT, selector::SelectorType, tables::LookupTable, @@ -102,7 +101,8 @@ pub struct ConstraintSystem { pub num_fixed: usize, pub fixed_namespace_map: Vec, - pub instance_openings: Vec, + // record which public input index is involving in constraint computation + pub instance: Vec, pub ec_point_exprs: Vec>, pub ec_slope_exprs: Vec>, @@ -175,7 +175,7 @@ impl ConstraintSystem { num_fixed: 0, fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), - instance_openings: vec![], + instance: vec![], ec_final_sum: vec![], ec_slope_exprs: vec![], ec_point_exprs: vec![], @@ -259,25 +259,14 @@ impl ConstraintSystem { f } - pub fn query_instance(&self, idx: usize) -> Result { + pub fn query_instance(&mut self, idx: usize) -> Result { let i = Instance(idx); - Ok(i) - } - - pub fn query_instance_for_openings( - &mut self, - idx: usize, - ) -> Result { - let i = Instance(idx); - assert!( - !self.instance_openings.contains(&i), - "query same pubio idx {idx} mle more than once", + !self.instance.contains(&i), + "query same pubio idx {idx} value more than once", ); - self.instance_openings.push(i); - - // return instance only count - Ok(Instance(self.instance_openings.len() - 1)) + self.instance.push(i); + Ok(Instance(self.instance.len() - 1)) } pub fn rlc_chip_record(&self, items: Vec>) -> Expression { diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index b025aa1e4..c7dfc8cad 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -123,7 +123,6 @@ impl GKRCircuit { gkr_proof: GKRProof, out_evals: &[PointAndEval], pub_io_evals: &[E], - raw_pi: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -145,7 +144,6 @@ impl GKRCircuit { layer_proof, &mut evaluations, pub_io_evals, - raw_pi, &mut challenges, transcript, selector_ctxs, diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a67862df7..c620fbd73 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -3,7 +3,7 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use linear_layer::{LayerClaims, LinearLayer}; use multilinear_extensions::{ - Expression, Instance, StructuralWitIn, ToExpr, + Expression, StructuralWitIn, ToExpr, mle::{Point, PointAndEval}, monomial::Term, }; @@ -75,8 +75,6 @@ pub struct Layer { pub max_expr_degree: usize, /// keep all structural witin which could be evaluated succinctly without PCS pub structural_witins: Vec, - /// instance openings - pub instance_openings: Vec, /// num challenges dedicated to this layer. pub n_challenges: usize, /// Expressions to prove in this layer. For zerocheck and linear layers, @@ -158,7 +156,6 @@ impl Layer { ), expr_names: Vec, structural_witins: Vec, - instance_openings: Vec, ) -> Self { assert_eq!(expr_names.len(), exprs.len(), "there are expr without name"); let max_expr_degree = exprs @@ -178,7 +175,6 @@ impl Layer { n_instance, max_expr_degree, structural_witins, - instance_openings, n_challenges, exprs, exprs_with_selector_out_eval_monomial_form: vec![], @@ -258,7 +254,6 @@ impl Layer { proof: LayerProof, claims: &mut [PointAndEval], pub_io_evals: &[E], - raw_pi: &[Vec], challenges: &mut Vec, transcript: &mut Trans, selector_ctxs: &[SelectorContext], @@ -273,7 +268,6 @@ impl Layer { proof, eval_and_dedup_points, pub_io_evals, - raw_pi, challenges, transcript, selector_ctxs, @@ -434,12 +428,23 @@ impl Layer { if let Some(lk_selector) = cb.cs.lk_selector.as_ref() { // process lookup records let evals = Self::dedup_last_selector_evals(lk_selector, &mut expr_evals); - for (idx, ((lookup, name), lookup_eval)) in (cb + for (idx, (((is_negate, lookup), name), lookup_eval)) in (cb .cs .lk_expressions .iter() - .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.multiplicity)) - .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.values))) + .map(|expr| (false, expr)) + .chain( + cb.cs + .lk_table_expressions + .iter() + .map(|t| (true, &t.multiplicity)), + ) + .chain( + cb.cs + .lk_table_expressions + .iter() + .map(|t| (false, &t.values)), + )) .zip_eq(if cb.cs.lk_table_expressions.is_empty() { Either::Left(cb.cs.lk_expressions_namespace_map.iter()) } else { @@ -454,13 +459,35 @@ impl Layer { .zip_eq(&lookup_evals) .enumerate() { - expressions.push(lookup - cb.cs.chip_record_alpha.clone()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - alpha (padding) - *lookup_eval, - E::BaseField::ONE.expr().into(), - cb.cs.chip_record_alpha.clone().neg().into(), - )); + // Encode lookup constraints in the canonical form: `sel * expression = evaluation`. + // + // Non-negated lookup: + // claim = sel * lookup + (1 - sel) * padding + // => claim - padding = sel * (lookup - padding) + // so we use `expression = lookup - padding` and `evaluation = claim - padding`. + // + // Negated lookup (`-lookup` used by multiplicity path): + // claim - padding = sel * (-lookup - padding) + // => padding - claim = sel * (lookup + padding) + // so we use `expression = lookup + padding` and `evaluation = padding - claim`. + if is_negate { + expressions.push(lookup + cb.cs.chip_record_alpha.clone()); + evals.push(EvalExpression::::Linear( + // evaluation = alpha (padding) - claim * one + *lookup_eval, + E::BaseField::ONE.neg().expr().into(), + cb.cs.chip_record_alpha.clone().into(), + )); + } else { + expressions.push(lookup - cb.cs.chip_record_alpha.clone()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - alpha (padding) + *lookup_eval, + E::BaseField::ONE.expr().into(), + cb.cs.chip_record_alpha.clone().neg().into(), + )); + }; + expr_names.push(format!("{}/{idx}", name)); } } @@ -495,7 +522,7 @@ impl Layer { } = &cb.cs; let in_eval_expr = (non_zero_expr_len..) - .take(cb.cs.num_witin as usize + cb.cs.num_fixed + cb.cs.instance_openings.len()) + .take(cb.cs.num_witin as usize + cb.cs.num_fixed) .collect_vec(); if rotations.is_empty() { Layer::new( @@ -504,7 +531,7 @@ impl Layer { cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - cb.cs.instance_openings.len(), + 0, expressions, n_challenges, in_eval_expr, @@ -512,7 +539,6 @@ impl Layer { ((None, vec![]), 0, 0), expr_names, cb.cs.structural_witins.clone(), - cb.cs.instance_openings.clone(), ) } else { let Some(RotationParams { @@ -529,7 +555,7 @@ impl Layer { cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - cb.cs.instance_openings.len(), + 0, expressions, n_challenges, in_eval_expr, @@ -541,7 +567,6 @@ impl Layer { ), expr_names, cb.cs.structural_witins.clone(), - cb.cs.instance_openings.clone(), ) } } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index cf1da6df2..cb0b91fa1 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -3,7 +3,7 @@ use itertools::{Itertools, chain, izip}; use multilinear_extensions::{ ChallengeId, Expression, StructuralWitIn, StructuralWitInType, ToExpr, WitnessId, macros::{entered_span, exit_span}, - mle::{IntoMLE, Point}, + mle::Point, monomial::Term, monomialize_expr_to_wit_terms, utils::{eval_by_expr, eval_by_expr_with_instance, expr_convert_to_witins}, @@ -77,7 +77,6 @@ pub trait ZerocheckLayer { proof: LayerProof, eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], - raw_pi: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -122,13 +121,15 @@ impl ZerocheckLayer for Layer { let sel_expr = sel_type.selector_expr(); let expr = match out_eval { EvalExpression::Linear(_, a, b) => { - assert_eq!( - a.as_ref().clone(), - E::BaseField::ONE.expr(), - "need to extend expression to support a.inverse()" - ); - // sel * exp - b - sel_expr.clone() * expr + b.as_ref().neg().clone() + // See `gkr_iop/src/gkr/layer.rs` for the +/-1 linear-coefficient derivation. + let coeff = a.as_ref(); + if *coeff == E::BaseField::ONE.expr() { + sel_expr.clone() * expr + b.as_ref().neg().clone() + } else if *coeff == E::BaseField::ONE.neg().expr() { + b.as_ref().clone() - sel_expr.clone() * expr + } else { + panic!("unsupported linear eval coefficient: expected +/-1") + } } EvalExpression::Single(_) => sel_expr.clone() * expr, EvalExpression::Zero => Expression::ZERO, @@ -228,7 +229,6 @@ impl ZerocheckLayer for Layer { proof: LayerProof, mut eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], - raw_pi: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -386,24 +386,6 @@ impl ZerocheckLayer for Layer { } } - // check pub-io - // assume public io is tiny vector, so we evaluate it directly without PCS - let pubio_offset = self.n_witin + self.n_fixed; - for (index, instance) in self.instance_openings.iter().enumerate() { - let index = pubio_offset + index; - let poly = raw_pi[instance.0].to_vec().into_mle(); - let expected_eval = poly.evaluate(&in_point[..poly.num_vars()]); - if expected_eval != main_evals[index] { - return Err(BackendError::LayerVerificationFailed( - format!("layer {} pi mismatch", self.name.clone()).into(), - VerifierError::ClaimNotMatch( - format!("{}", expected_eval).into(), - format!("{}", main_evals[index]).into(), - ), - )); - } - } - let got_claim = eval_by_expr_with_instance( &[], &main_evals, diff --git a/gkr_iop/src/gkr/layer_constraint_system.rs b/gkr_iop/src/gkr/layer_constraint_system.rs index 6bbb1cdc7..eb30aa87b 100644 --- a/gkr_iop/src/gkr/layer_constraint_system.rs +++ b/gkr_iop/src/gkr/layer_constraint_system.rs @@ -425,7 +425,6 @@ impl LayerConstraintSystem { ((None, vec![]), 0, 0), expr_names, vec![], - vec![], ) } else { let Some(RotationParams { @@ -454,7 +453,6 @@ impl LayerConstraintSystem { ), expr_names, vec![], - vec![], ) } }