From 901d8b69dbaa5fef1f4ccca44e77d44073946c5b Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 27 Mar 2026 09:57:04 +0800 Subject: [PATCH 1/6] commit pi as hash digest --- Cargo.lock | 1 + ceno_cli/example/src/main.rs | 2 +- ceno_emul/src/lib.rs | 1 + ceno_emul/src/platform.rs | 29 +++++++-------- ceno_emul/src/syscalls.rs | 11 +++++- ceno_host/src/lib.rs | 8 +---- ceno_rt/Cargo.toml | 1 + ceno_rt/ceno_link.x | 10 +----- ceno_rt/memory.x | 3 +- ceno_rt/src/mmio.rs | 41 +++++++++++++++------- ceno_zkvm/src/e2e.rs | 27 +++++++++----- ceno_zkvm/src/instructions/riscv/rv32im.rs | 22 +++++++++++- ceno_zkvm/src/scheme.rs | 3 ++ ceno_zkvm/src/scheme/tests.rs | 2 +- ceno_zkvm/src/tables/shard_ram.rs | 1 + examples/examples/fibonacci.rs | 4 +-- examples/examples/sha256.rs | 6 ++-- 17 files changed, 107 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f57867be..2598061ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1171,6 +1171,7 @@ dependencies = [ "getrandom 0.2.16", "getrandom 0.3.2", "serde", + "tiny-keccak", ] [[package]] diff --git a/ceno_cli/example/src/main.rs b/ceno_cli/example/src/main.rs index fc88e6613..aa6f5f655 100644 --- a/ceno_cli/example/src/main.rs +++ b/ceno_cli/example/src/main.rs @@ -27,5 +27,5 @@ fn main() { panic!(); } - ceno_rt::commit(&cnt_primes); + ceno_rt::commit(&cnt_primes.to_le_bytes()); } diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 6b16a3587..10ebcf587 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -41,6 +41,7 @@ pub use syscalls::{ }, keccak_permute::{KECCAK_WORDS, KeccakSpec}, phantom::LogPcCycleSpec, + PubIoCommitSpec, secp256k1::{ COORDINATE_WORDS as SECP256K1_COORDINATE_WORDS, SECP256K1_ARG_WORDS, Secp256k1AddSpec, Secp256k1DecompressSpec, Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 75c7e8f11..b02af5d56 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -13,7 +13,6 @@ use crate::addr::{Addr, RegIdx}; pub struct Platform { pub rom: Range, pub prog_data: Arc>, - pub public_io: Range, pub stack: Range, pub heap: Range, @@ -34,7 +33,7 @@ impl Display for Platform { write!( f, "Platform {{ rom: {:#x}..{:#x}, prog_data: {:#x}..{:#x}, stack: {:#x}..{:#x}, heap: {:#x}..{:#x}, \ - public_io: {:#x}..{:#x}, hints: {:#x}..{:#x}, unsafe_ecall_nop: {} }}", + hints: {:#x}..{:#x}, unsafe_ecall_nop: {} }}", self.rom.start, self.rom.end, prog_data @@ -49,8 +48,6 @@ impl Display for Platform { self.stack.end, self.heap.start, self.heap.end, - self.public_io.start, - self.public_io.end, self.hints.start, self.hints.end, self.unsafe_ecall_nop @@ -81,12 +78,13 @@ impl Display for Platform { // │ STACK (≈128 MB, grows downward) // │ 0x1800_0000 .. 0x2000_0000 // │ -// ├───────────────────────────── 0x1800_0000 (stack base / pubio end) +// ├───────────────────────────── 0x1800_0000 (stack base) // │ -// │ PUBLIC I/O (128 MB) -// │ 0x1000_0000 .. 0x1800_0000 +// │ STACK (stack-only memory window, includes reserved low-address area +// │ previously used for PUBLIC I/O) +// │ 0x1000_0000 .. 0x2000_0000 // │ -// ├───────────────────────────── 0x1000_0000 (pubio base / rom end) +// ├───────────────────────────── 0x1000_0000 (stack base / rom end) // │ // │ ROM / TEXT / RODATA (128 MB) // │ 0x0800_0000 .. 0x1000_0000 @@ -94,8 +92,7 @@ impl Display for Platform { // └───────────────────────────── 0x8000_0000 (rom base) pub static CENO_PLATFORM: Lazy = Lazy::new(|| Platform { rom: 0x0800_0000..0x1000_0000, // 128 MB - public_io: 0x1000_0000..0x1800_0000, // 128 MB - stack: 0x1800_0000..0x2000_4000, // stack grows downward 128MB, 0x4000 reserved for debug io. + stack: 0x1000_0000..0x2000_4000, // stack grows downward, 0x4000 reserved for debug io. // we make hints start from 0x2800_0000 thus reserve a 128MB gap for debug io // at the end of stack hints: 0x2800_0000..0x3000_0000, // 128 MB @@ -123,10 +120,6 @@ impl Platform { self.stack.contains(&addr) || self.heap.contains(&addr) || self.is_prog_data(addr) } - pub fn is_pub_io(&self, addr: Addr) -> bool { - self.public_io.contains(&addr) - } - pub fn is_hints(&self, addr: Addr) -> bool { self.hints.contains(&addr) } @@ -155,7 +148,7 @@ impl Platform { } pub fn can_write(&self, addr: Addr) -> bool { - self.is_ram(addr) || self.is_pub_io(addr) || self.is_hints(addr) + self.is_ram(addr) || self.is_hints(addr) } // Environment calls. @@ -180,6 +173,11 @@ impl Platform { 0 } + /// The code of ecall PUB_IO_COMMIT. + pub const fn ecall_pub_io_commit() -> u32 { + 1 + } + /// The code of success. pub const fn code_success() -> u32 { 0 @@ -191,7 +189,6 @@ impl Platform { &self.rom, &self.stack, &self.heap, - &self.public_io, &self.hints, ]; ranges.sort_by_key(|r| r.start); diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 5d9674fc6..1027b6961 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -1,4 +1,4 @@ -use crate::{RegIdx, Tracer, VMState, Word, WordAddr, WriteOp}; +use crate::{Platform, RegIdx, Tracer, VMState, Word, WordAddr, WriteOp}; use anyhow::Result; pub mod bn254; @@ -30,6 +30,14 @@ pub trait SyscallSpec { const GKR_OUTPUTS: usize = 0; } +pub struct PubIoCommitSpec; +impl SyscallSpec for PubIoCommitSpec { + const NAME: &'static str = "PUB_IO_COMMIT"; + const REG_OPS_COUNT: usize = 0; + const MEM_OPS_COUNT: usize = 0; + const CODE: u32 = Platform::ecall_pub_io_commit(); +} + /// Trace the inputs and effects of a syscall. pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result { match function_code { @@ -49,6 +57,7 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< BN254_FP2_ADD => Ok(bn254::bn254_fp2_add(vm)), BN254_FP2_MUL => Ok(bn254::bn254_fp2_mul(vm)), UINT256_MUL => Ok(uint256::uint256_mul(vm)), + code if code == PubIoCommitSpec::CODE => Ok(SyscallEffects::default()), // phantom syscall PHANTOM_LOG_PC_CYCLE => Ok(phantom::log_pc_cycle(vm)), diff --git a/ceno_host/src/lib.rs b/ceno_host/src/lib.rs index 719288ece..cfe9a9d1d 100644 --- a/ceno_host/src/lib.rs +++ b/ceno_host/src/lib.rs @@ -115,7 +115,7 @@ pub fn run( platform: Platform, elf: &[u8], hints: &CenoStdin, - public_io: Option<&CenoStdin>, + _public_io: Option<&CenoStdin>, ) -> Vec> { let program = Program::load_elf(elf, u32::MAX).unwrap(); let platform = Platform { @@ -124,9 +124,7 @@ pub fn run( }; let hints: Vec = hints.into(); - let pubio: Vec = public_io.map(|c| c.into()).unwrap_or_default(); let hints_range = platform.hints.clone(); - let pubio_range = platform.public_io.clone(); let mut state = VMState::new(platform, Arc::new(program)); @@ -134,10 +132,6 @@ pub fn run( state.init_memory(addr.into(), value); } - for (addr, value) in zip(pubio_range.iter_addresses(), pubio) { - state.init_memory(addr.into(), value); - } - state .iter_until_halt() .collect::>>() diff --git a/ceno_rt/Cargo.toml b/ceno_rt/Cargo.toml index f211c53ce..62ff26613 100644 --- a/ceno_rt/Cargo.toml +++ b/ceno_rt/Cargo.toml @@ -14,3 +14,4 @@ ceno_serde = { path = "../ceno_serde" } getrandom = { version = "0.2.15", features = ["custom"], default-features = false } getrandom_v3 = { package = "getrandom", version = "0.3", default-features = false } serde.workspace = true +tiny-keccak.workspace = true diff --git a/ceno_rt/ceno_link.x b/ceno_rt/ceno_link.x index f4c633ba0..0aa7dd928 100644 --- a/ceno_rt/ceno_link.x +++ b/ceno_rt/ceno_link.x @@ -4,11 +4,7 @@ _hints_start = ORIGIN(REGION_HINTS) + 128M; _hints_length = 128M; _lengths_of_hints_start = ORIGIN(REGION_HINTS) + 128M; -_lengths_of_pubio_start = ORIGIN(REGION_PUBIO); -_pubio_start = ORIGIN(REGION_PUBIO); /* 0x20000000 */ -_pubio_end = ORIGIN(REGION_PUBIO) + 128M; /* PUBIO grows upward */ -_pubio_length = 128M; -_stack_start = ORIGIN(REGION_PUBIO) + 256M; /* stack grows downward */ +_stack_start = ORIGIN(REGION_STACK) + LENGTH(REGION_STACK); /* stack grows downward */ SECTIONS { @@ -25,10 +21,6 @@ SECTIONS *(.rodata .rodata.*); } > ROM - .pubio (NOLOAD): ALIGN(4) - { - *(.pubio .pubio.*); - } > STACK_PUBIO .stack (NOLOAD) : ALIGN(4) { diff --git a/ceno_rt/memory.x b/ceno_rt/memory.x index 2f95e9ae4..d42358be7 100644 --- a/ceno_rt/memory.x +++ b/ceno_rt/memory.x @@ -1,7 +1,7 @@ MEMORY { ROM (rx) : ORIGIN = 0x08000000, LENGTH = 128M - STACK_PUBIO (rw) : ORIGIN = 0x10000000, LENGTH = 256M /* PUBIO first 128M, Stack second 128M */ + STACK_PUBIO (rw) : ORIGIN = 0x10000000, LENGTH = 256M /* Stack region */ HINTS (r) : ORIGIN = 0x20000000, LENGTH = 256M /* will shift hint to 0x28000000 with 128M to reserve gap*/ RAM (rw) : ORIGIN = 0x30000000, LENGTH = 256M /* heap/data/bss */ } @@ -9,7 +9,6 @@ MEMORY REGION_ALIAS("REGION_TEXT", ROM); REGION_ALIAS("REGION_RODATA", ROM); -REGION_ALIAS("REGION_PUBIO", STACK_PUBIO); REGION_ALIAS("REGION_STACK", STACK_PUBIO); REGION_ALIAS("REGION_HINTS", HINTS); diff --git a/ceno_rt/src/mmio.rs b/ceno_rt/src/mmio.rs index ead07a7ab..828c5daf8 100644 --- a/ceno_rt/src/mmio.rs +++ b/ceno_rt/src/mmio.rs @@ -3,6 +3,7 @@ use ceno_serde::from_slice; use core::{cell::UnsafeCell, ptr, slice::from_raw_parts}; use serde::de::DeserializeOwned; +use tiny_keccak::{Hasher, Keccak}; struct RegionState { next_len_at: *const usize, @@ -58,8 +59,6 @@ impl RegionState { unsafe extern "C" { static _hints_start: u8; static _lengths_of_hints_start: usize; - static _pubio_start: u8; - static _lengths_of_pubio_start: usize; } struct RegionStateCell(UnsafeCell); @@ -77,7 +76,6 @@ impl RegionStateCell { unsafe impl Sync for RegionStateCell {} static HINT_STATE: RegionStateCell = RegionStateCell::new(); -static PUBIO_STATE: RegionStateCell = RegionStateCell::new(); pub fn read_slice<'a>() -> &'a [u8] { unsafe { @@ -100,18 +98,35 @@ where read_owned() } -pub fn pubio_read_slice<'a>() -> &'a [u8] { +#[cfg(target_arch = "riscv32")] +#[inline(always)] +fn syscall_pub_io_commit(digest_u16_limbs: &[u32; 16]) { + // a0 carries a pointer to the digest limbs, a1 carries the limb count. unsafe { - PUBIO_STATE - .with_mut(|state| state.take_slice(&raw const _lengths_of_pubio_start, &_pubio_start)) + core::arch::asm!( + "ecall", + in("a0") digest_u16_limbs.as_ptr(), + in("a1") digest_u16_limbs.len(), + in("t0") 1_u32, + ); } } -/// Read a value from public io, deserialize it, and assert that it matches the given value. -pub fn commit(v: &T) -where - T: DeserializeOwned + core::fmt::Debug + PartialEq, -{ - let expected: T = from_slice(pubio_read_slice()).expect("Deserialised value failed."); - assert_eq!(*v, expected); +#[cfg(not(target_arch = "riscv32"))] +#[inline(always)] +fn syscall_pub_io_commit(_digest_u16_limbs: &[u32; 16]) {} + +fn digest_to_u16_limbs(digest: [u8; 32]) -> [u32; 16] { + core::array::from_fn(|i| u16::from_le_bytes([digest[i * 2], digest[i * 2 + 1]]) as u32) +} + +/// Commit arbitrary public bytes by hashing with Keccak-256 and emitting digest limbs. +pub fn commit(data: &[u8]) { + let mut keccak = Keccak::v256(); + keccak.update(data); + let mut digest = [0u8; 32]; + keccak.finalize(&mut digest); + + let digest_u16_limbs = digest_to_u16_limbs(digest); + syscall_pub_io_commit(&digest_u16_limbs); } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 0c3d936b7..b777928e3 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -41,6 +41,7 @@ use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; use rustc_hash::FxHashSet; use serde::Serialize; +use tiny_keccak::{Hasher, Keccak}; #[cfg(debug_assertions)] use std::collections::{HashMap, HashSet}; use std::{ @@ -57,6 +58,17 @@ use witness::next_pow2_instance_padding; pub const DEFAULT_MAX_CELLS_PER_SHARDS: u64 = (1 << 30) * 16 / 4 / 2; pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 29; pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT: usize = 1 << 20; + +pub fn public_io_words_to_digest_u16_limbs(words: &[u32]) -> [u32; 16] { + let mut keccak = Keccak::v256(); + for word in words { + keccak.update(&word.to_le_bytes()); + } + let mut digest = [0u8; 32]; + keccak.finalize(&mut digest); + core::array::from_fn(|i| u16::from_le_bytes([digest[2 * i], digest[2 * i + 1]]) as u32) +} + // define a relative small number to make first shard handle much less instruction /// The polynomial commitment scheme kind #[derive( @@ -1021,6 +1033,7 @@ pub fn emulate_program<'a>( platform.hints.start, hints_final.len() as u32, io_init.iter().map(|rec| rec.value).collect_vec(), + public_io_words_to_digest_u16_limbs(&io_init.iter().map(|rec| rec.value).collect_vec()), [0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); @@ -1139,17 +1152,13 @@ fn setup_platform_inner( heap.start..heap_end as u32 }; - assert!( - pub_io_size.is_power_of_two(), - "pub io size {pub_io_size} must be a power of two" - ); + let _ = pub_io_size; let platform = Platform { rom: program.base_address ..program.base_address + (program.instructions.len() * WORD_SIZE) as u32, prog_data, stack, heap, - public_io: preset.public_io.start..preset.public_io.start + pub_io_size, ..preset }; assert!( @@ -1552,7 +1561,7 @@ pub fn setup_program( multi_prover: MultiProver, ) -> E2EProgramCtx { let static_addrs = init_static_addrs(&program); - let pubio_len = platform.public_io.iter_addresses().len(); + let pubio_len = 0; let program_params = ProgramParams { platform: platform.clone(), program_size: next_pow2_instance_padding(program.instructions.len()), @@ -1561,7 +1570,7 @@ pub fn setup_program( }; let system_config = construct_configs::(program_params); let reg_init = system_config.mmu_config.initial_registers(); - let io_init = MemPadder::new_mem_records_uninit(platform.public_io.clone(), pubio_len); + let io_init: Vec = vec![]; // Generate fixed traces let zkvm_fixed_traces = generate_fixed_traces( @@ -1635,8 +1644,8 @@ impl E2EProgramCtx { /// Setup init mem state pub fn setup_init_mem(&self, hints: &[u32], public_io: &[u32]) -> InitMemState { - let mut io_init = self.io_init.clone(); - MemPadder::init_mem_records(&mut io_init, public_io); + let _ = public_io; + let io_init = self.io_init.clone(); let hint_init = MemPadder::new_mem_records( self.platform.hints.clone(), hints.len().next_power_of_two(), diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 091ce3000..09ea23733 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -45,7 +45,7 @@ use ceno_emul::{ Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, InsnKind::{self, *}, - KeccakSpec, LogPcCycleSpec, Platform, Secp256k1AddSpec, Secp256k1DecompressSpec, + KeccakSpec, LogPcCycleSpec, Platform, PubIoCommitSpec, Secp256k1AddSpec, Secp256k1DecompressSpec, Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, Secp256r1AddSpec, Secp256r1DoubleSpec, Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepCellExtractor, StepIndex, StepRecord, SyscallSpec, Uint256MulSpec, Word, @@ -73,6 +73,7 @@ use strum::{EnumCount, IntoEnumIterator}; pub mod mmu; const ECALL_HALT: u32 = Platform::ecall_halt(); +const ECALL_PUB_IO_COMMIT: u32 = Platform::ecall_pub_io_commit(); pub struct Rv32imConfig { // ALU Opcodes. @@ -134,6 +135,8 @@ pub struct Rv32imConfig { // Ecall Opcodes pub halt_config: as Instruction>::InstructionConfig, + pub pubio_commit_config: + as Instruction>::InstructionConfig, pub keccak_config: as Instruction>::InstructionConfig, pub sha_extend_config: as Instruction>::InstructionConfig, pub bn254_add_config: @@ -355,6 +358,8 @@ impl Rv32imConfig { }}; } let halt_config = register_ecall_circuit!(HaltInstruction, ecall_cells_map); + let pubio_commit_config = + register_ecall_circuit!(LargeEcallDummy, ecall_cells_map); // Keccak precompile is a known hotspot for peak memory. // Its heavy read/write/LK activity inflates tower-witness usage, causing @@ -468,6 +473,7 @@ impl Rv32imConfig { lb_config, // ecall opcodes halt_config, + pubio_commit_config, keccak_config, sha_extend_config, bn254_add_config, @@ -562,6 +568,10 @@ impl Rv32imConfig { // system fixed.register_opcode_circuit::>(cs, &self.halt_config); + fixed.register_opcode_circuit::>( + cs, + &self.pubio_commit_config, + ); fixed.register_opcode_circuit::>(cs, &self.keccak_config); fixed.register_opcode_circuit::>(cs, &self.sha_extend_config); fixed.register_opcode_circuit::>>( @@ -650,6 +660,7 @@ impl Rv32imConfig { } log_ecall!("HALT", ECALL_HALT); + log_ecall!("PUB_IO_COMMIT", ECALL_PUB_IO_COMMIT); log_ecall!("KECCAK", KeccakSpec::CODE); log_ecall!("bn254_add_records", Bn254AddSpec::CODE); log_ecall!("bn254_double_records", Bn254DoubleSpec::CODE); @@ -761,6 +772,11 @@ impl Rv32imConfig { // ecall / halt assign_ecall!(HaltInstruction, halt_config, ECALL_HALT); + assign_ecall!( + LargeEcallDummy, + pubio_commit_config, + ECALL_PUB_IO_COMMIT + ); assign_ecall!(KeccakInstruction, keccak_config, KeccakSpec::CODE); assign_ecall!( WeierstrassAddAssignInstruction>, @@ -1042,6 +1058,10 @@ impl Rv32imConfig { .ecall_cells_map .get(&HaltInstruction::::name()) .expect("unable to find name"), + ECALL_PUB_IO_COMMIT => *self + .ecall_cells_map + .get(&LargeEcallDummy::::name()) + .expect("unable to find name"), KeccakSpec::CODE => *self .ecall_cells_map .get(&KeccakInstruction::::name()) diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index b105eb7f9..50475fe0f 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -87,6 +87,7 @@ pub struct PublicValues { pub hint_start_addr: u32, pub hint_shard_len: u32, pub public_io: Vec, + pub pubio_digest: [u32; 16], pub shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], } @@ -104,6 +105,7 @@ impl PublicValues { hint_start_addr: u32, hint_shard_len: u32, public_io: Vec, + pubio_digest: [u32; 16], shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], ) -> Self { Self { @@ -118,6 +120,7 @@ impl PublicValues { hint_start_addr, hint_shard_len, public_io, + pubio_digest, shard_rw_sum, } } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 763c3bb4e..d8e1c9029 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -396,7 +396,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], [0; 14]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], [0; 16], [0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(&shard_ctx, zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index f9caf5513..27df29e90 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -775,6 +775,7 @@ mod tests { 0, 0, vec![0], // dummy + [0; 16], shard_rw_sum, ); diff --git a/examples/examples/fibonacci.rs b/examples/examples/fibonacci.rs index ae44e6707..2f413677f 100644 --- a/examples/examples/fibonacci.rs +++ b/examples/examples/fibonacci.rs @@ -12,6 +12,6 @@ fn main() { a = b; b = c; } - // Constrain with public io - ceno_rt::commit(&b); + // Constrain with public io digest. + ceno_rt::commit(&b.to_le_bytes()); } diff --git a/examples/examples/sha256.rs b/examples/examples/sha256.rs index f485d9d29..d9239b536 100644 --- a/examples/examples/sha256.rs +++ b/examples/examples/sha256.rs @@ -9,13 +9,13 @@ fn main() { let input: Vec = ceno_rt::read(); let h = Sha256::digest(&input); - let h: [u8; 32] = h.into(); + let h_bytes: [u8; 32] = h.into(); let h: [u32; 8] = core::array::from_fn(|i| { - let chunk = &h[4 * i..][..4]; + let chunk = &h_bytes[4 * i..][..4]; u32::from_be_bytes(chunk.try_into().unwrap()) }); // Output the final hash values one by one - ceno_rt::commit(&h); + ceno_rt::commit(&h_bytes); // debug_print!("{:x}", h[0]); } From 274aabcbe25a6156e6166b7df1ce67ffec4310a6 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 27 Mar 2026 15:17:52 +0800 Subject: [PATCH 2/6] pub io constrain via digest; pub io circuit cleanup --- Cargo.lock | 5 +- ceno_cli/src/commands/common_args/ceno.rs | 10 +- ceno_cli/src/sdk.rs | 12 +- ceno_emul/src/platform.rs | 5 - ceno_emul/src/syscalls.rs | 11 +- ceno_emul/src/syscalls/pubio_commit.rs | 28 +++ ceno_rt/Cargo.toml | 1 + ceno_rt/src/mmio.rs | 34 ++-- ceno_zkvm/src/chip_handler/general.rs | 16 +- ceno_zkvm/src/e2e.rs | 69 ++++---- ceno_zkvm/src/instructions/riscv/constants.rs | 3 + ceno_zkvm/src/instructions/riscv/ecall.rs | 2 + .../instructions/riscv/ecall/pubio_commit.rs | 163 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/insn_base.rs | 15 +- ceno_zkvm/src/instructions/riscv/rv32im.rs | 23 ++- .../src/instructions/riscv/rv32im/mmu.rs | 32 +--- ceno_zkvm/src/precompiles/mod.rs | 2 + ceno_zkvm/src/precompiles/pubio_commit.rs | 39 +++++ ceno_zkvm/src/scheme.rs | 19 +- ceno_zkvm/src/scheme/prover.rs | 8 +- ceno_zkvm/src/scheme/tests.rs | 2 +- ceno_zkvm/src/tables/ram.rs | 20 +-- ceno_zkvm/src/tables/shard_ram.rs | 2 +- 23 files changed, 375 insertions(+), 146 deletions(-) create mode 100644 ceno_emul/src/syscalls/pubio_commit.rs create mode 100644 ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs create mode 100644 ceno_zkvm/src/precompiles/pubio_commit.rs diff --git a/Cargo.lock b/Cargo.lock index 2598061ee..1a55db8c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1061,7 +1061,7 @@ dependencies = [ [[package]] name = "ceno_crypto_primitives" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#b79232a6fc80c799f584380273fd6d3055b36808" +source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#6e90bc85bceefe09003f84ca5cc1afbe2911a2db" dependencies = [ "ceno_syscall", "elliptic-curve", @@ -1168,6 +1168,7 @@ name = "ceno_rt" version = "0.1.0" dependencies = [ "ceno_serde", + "ceno_syscall", "getrandom 0.2.16", "getrandom 0.3.2", "serde", @@ -1194,7 +1195,7 @@ dependencies = [ [[package]] name = "ceno_syscall" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#b79232a6fc80c799f584380273fd6d3055b36808" +source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#6e90bc85bceefe09003f84ca5cc1afbe2911a2db" [[package]] name = "ceno_zkvm" diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index 584702bd4..75be1322b 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -367,14 +367,16 @@ fn run_elf_inner< options.max_cycle_per_shard, ); - let public_io = options + let public_io_digest_input = options .read_public_io() .context("failed to read public io")?; + let public_io_digest = public_io_words_to_digest_words(&public_io_digest_input); + tracing::debug!("public io digest words: {:?}", public_io_digest); let public_io_size = options.public_io_size; assert!( - public_io.len() <= public_io_size as usize / WORD_SIZE, + public_io_digest_input.len() <= public_io_size as usize / WORD_SIZE, "require pub io length {} < max public_io_size {}", - public_io.len(), + public_io_digest_input.len(), public_io_size as usize / WORD_SIZE ); @@ -416,7 +418,7 @@ fn run_elf_inner< platform, multi_prover, &hints, - &public_io, + &public_io_digest_input, options.max_steps, checkpoint, options.shard_id.map(|v| v as usize), diff --git a/ceno_cli/src/sdk.rs b/ceno_cli/src/sdk.rs index cbcb601b9..ce31fcd8b 100644 --- a/ceno_cli/src/sdk.rs +++ b/ceno_cli/src/sdk.rs @@ -153,8 +153,16 @@ where shard_id: Option, ) -> Vec> { if let Some(zkvm_prover) = self.zkvm_prover.as_ref() { - let init_full_mem = zkvm_prover.setup_init_mem(&Vec::from(&hints), &Vec::from(&pub_io)); - run_e2e_proof::(zkvm_prover, &init_full_mem, max_steps, false, shard_id) + let public_io_words = Vec::from(&pub_io); + let init_full_mem = zkvm_prover.setup_init_mem(&Vec::from(&hints), &public_io_words); + run_e2e_proof::( + zkvm_prover, + &init_full_mem, + &public_io_words, + max_steps, + false, + shard_id, + ) } else { panic!("ZKVMProver is not initialized") } diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index b02af5d56..bfcce5cff 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -173,11 +173,6 @@ impl Platform { 0 } - /// The code of ecall PUB_IO_COMMIT. - pub const fn ecall_pub_io_commit() -> u32 { - 1 - } - /// The code of success. pub const fn code_success() -> u32 { 0 diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 1027b6961..547412bc7 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -1,9 +1,10 @@ -use crate::{Platform, RegIdx, Tracer, VMState, Word, WordAddr, WriteOp}; +use crate::{RegIdx, Tracer, VMState, Word, WordAddr, WriteOp}; use anyhow::Result; pub mod bn254; pub mod keccak_permute; pub mod phantom; +pub mod pubio_commit; pub mod secp256k1; pub(crate) mod secp256r1; pub mod sha256; @@ -33,9 +34,9 @@ pub trait SyscallSpec { pub struct PubIoCommitSpec; impl SyscallSpec for PubIoCommitSpec { const NAME: &'static str = "PUB_IO_COMMIT"; - const REG_OPS_COUNT: usize = 0; - const MEM_OPS_COUNT: usize = 0; - const CODE: u32 = Platform::ecall_pub_io_commit(); + const REG_OPS_COUNT: usize = 1; + const MEM_OPS_COUNT: usize = 8; + const CODE: u32 = ceno_syscall::PUB_IO_COMMIT; } /// Trace the inputs and effects of a syscall. @@ -57,7 +58,7 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< BN254_FP2_ADD => Ok(bn254::bn254_fp2_add(vm)), BN254_FP2_MUL => Ok(bn254::bn254_fp2_mul(vm)), UINT256_MUL => Ok(uint256::uint256_mul(vm)), - code if code == PubIoCommitSpec::CODE => Ok(SyscallEffects::default()), + code if code == PubIoCommitSpec::CODE => Ok(pubio_commit::pubio_commit(vm)), // phantom syscall PHANTOM_LOG_PC_CYCLE => Ok(phantom::log_pc_cycle(vm)), diff --git a/ceno_emul/src/syscalls/pubio_commit.rs b/ceno_emul/src/syscalls/pubio_commit.rs new file mode 100644 index 000000000..e182383ee --- /dev/null +++ b/ceno_emul/src/syscalls/pubio_commit.rs @@ -0,0 +1,28 @@ +use crate::{Change, EmuContext, Platform, Tracer, VMState, WriteOp, utils::MemoryView}; + +use super::{PubIoCommitSpec, SyscallEffects, SyscallSpec, SyscallWitness}; + +const PUBIO_COMMIT_WORDS: usize = 8; + +/// Trace the PUB_IO_COMMIT syscall by reading 8 digest words from guest memory. +pub fn pubio_commit(vm: &VMState) -> SyscallEffects { + let digest_ptr = vm.peek_register(Platform::reg_arg0()); + + let reg_ops = vec![WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(digest_ptr, digest_ptr), + 0, + )]; + + let digest_view = MemoryView::<_, PUBIO_COMMIT_WORDS>::new(vm, digest_ptr); + let mem_ops = digest_view.mem_ops().to_vec(); + + assert_eq!(mem_ops.len(), PubIoCommitSpec::MEM_OPS_COUNT); + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} + + + diff --git a/ceno_rt/Cargo.toml b/ceno_rt/Cargo.toml index 62ff26613..83b25cc3e 100644 --- a/ceno_rt/Cargo.toml +++ b/ceno_rt/Cargo.toml @@ -11,6 +11,7 @@ version = "0.1.0" [dependencies] ceno_serde = { path = "../ceno_serde" } +ceno_syscall.workspace = true getrandom = { version = "0.2.15", features = ["custom"], default-features = false } getrandom_v3 = { package = "getrandom", version = "0.3", default-features = false } serde.workspace = true diff --git a/ceno_rt/src/mmio.rs b/ceno_rt/src/mmio.rs index 828c5daf8..62c4a847b 100644 --- a/ceno_rt/src/mmio.rs +++ b/ceno_rt/src/mmio.rs @@ -2,6 +2,7 @@ use ceno_serde::from_slice; use core::{cell::UnsafeCell, ptr, slice::from_raw_parts}; +use ceno_syscall::syscall_pub_io_commit; use serde::de::DeserializeOwned; use tiny_keccak::{Hasher, Keccak}; @@ -98,26 +99,15 @@ where read_owned() } -#[cfg(target_arch = "riscv32")] -#[inline(always)] -fn syscall_pub_io_commit(digest_u16_limbs: &[u32; 16]) { - // a0 carries a pointer to the digest limbs, a1 carries the limb count. - unsafe { - core::arch::asm!( - "ecall", - in("a0") digest_u16_limbs.as_ptr(), - in("a1") digest_u16_limbs.len(), - in("t0") 1_u32, - ); - } -} - -#[cfg(not(target_arch = "riscv32"))] -#[inline(always)] -fn syscall_pub_io_commit(_digest_u16_limbs: &[u32; 16]) {} - -fn digest_to_u16_limbs(digest: [u8; 32]) -> [u32; 16] { - core::array::from_fn(|i| u16::from_le_bytes([digest[i * 2], digest[i * 2 + 1]]) as u32) +fn digest_to_words(digest: [u8; 32]) -> [u32; 8] { + core::array::from_fn(|i| { + u32::from_le_bytes([ + digest[i * 4], + digest[i * 4 + 1], + digest[i * 4 + 2], + digest[i * 4 + 3], + ]) + }) } /// Commit arbitrary public bytes by hashing with Keccak-256 and emitting digest limbs. @@ -127,6 +117,6 @@ pub fn commit(data: &[u8]) { let mut digest = [0u8; 32]; keccak.finalize(&mut digest); - let digest_u16_limbs = digest_to_u16_limbs(digest); - syscall_pub_io_commit(&digest_u16_limbs); + let digest_words = digest_to_words(digest); + syscall_pub_io_commit(&digest_words); } diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index bea259603..8606baabd 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -6,7 +6,7 @@ use crate::{ instructions::riscv::constants::{ END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, - SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, + PUBIO_DIGEST_IDX, PUBIO_DIGEST_U16_LIMBS, SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, }, scheme::constants::SEPTIC_EXTENSION_DEGREE, tables::InsnRecord, @@ -25,6 +25,9 @@ pub trait PublicValuesQuery { fn query_end_cycle(&mut self) -> Result; fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError>; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; + fn query_public_io_digest( + &mut self, + ) -> Result<[Instance; PUBIO_DIGEST_U16_LIMBS], CircuitBuilderError>; #[allow(dead_code)] fn query_shard_id(&mut self) -> Result; fn query_heap_start_addr(&mut self) -> Result; @@ -91,6 +94,17 @@ impl<'a, E: ExtensionField> PublicValuesQuery for CircuitBuilder<'a, E> { ]) } + fn query_public_io_digest( + &mut self, + ) -> Result<[Instance; PUBIO_DIGEST_U16_LIMBS], CircuitBuilderError> { + let limbs = (0..PUBIO_DIGEST_U16_LIMBS) + .map(|i| self.cs.query_instance(PUBIO_DIGEST_IDX + i)) + .collect::, _>>()?; + Ok(limbs + .try_into() + .expect("pubio digest instance limb count must be fixed")) + } + fn query_shard_id(&mut self) -> Result { self.cs.query_instance(SHARD_ID_IDX) } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index b777928e3..91f7106db 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -59,14 +59,29 @@ pub const DEFAULT_MAX_CELLS_PER_SHARDS: u64 = (1 << 30) * 16 / 4 / 2; pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 29; pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT: usize = 1 << 20; -pub fn public_io_words_to_digest_u16_limbs(words: &[u32]) -> [u32; 16] { +pub fn public_io_words_to_digest_words(words: &[u32]) -> [u32; 8] { let mut keccak = Keccak::v256(); for word in words { keccak.update(&word.to_le_bytes()); } let mut digest = [0u8; 32]; keccak.finalize(&mut digest); - core::array::from_fn(|i| u16::from_le_bytes([digest[2 * i], digest[2 * i + 1]]) as u32) + #[cfg(target_endian = "little")] + { + // Reinterpret Keccak digest bytes as 8 little-endian u32 words. + unsafe { core::mem::transmute::<[u8; 32], [u32; 8]>(digest) } + } + #[cfg(not(target_endian = "little"))] + { + core::array::from_fn(|i| { + u32::from_le_bytes([ + digest[i * 4], + digest[i * 4 + 1], + digest[i * 4 + 2], + digest[i * 4 + 3], + ]) + }) + } } // define a relative small number to make first shard handle much less instruction @@ -782,7 +797,7 @@ impl StepReplay { ) -> Self { let mut vm = VMState::new_with_tracer_config(platform, program, FullTracerConfig { max_step_shard }); - for record in chain!(init_mem_state.hints.iter(), init_mem_state.io.iter()) { + for record in init_mem_state.hints.iter() { vm.init_memory(record.addr.into(), record.value); } StepReplay { @@ -840,13 +855,14 @@ pub fn emulate_program<'a>( program: Arc, max_steps: usize, init_mem_state: &InitMemState, + public_io_digest_input: &[u32], platform: &Platform, multi_prover: &MultiProver, step_cell_extractor: Arc, ) -> EmulationResult<'a> { let InitMemState { mem: mem_init, - io: io_init, + io: _, reg: reg_init, hints: hints_init, stack: _, @@ -865,7 +881,7 @@ pub fn emulate_program<'a>( }); info_span!("[ceno] emulator.init_mem").in_scope(|| { - for record in chain!(hints_init, io_init) { + for record in hints_init { vm.init_memory(record.addr.into(), record.value); } }); @@ -949,17 +965,8 @@ pub fn emulate_program<'a>( }) .collect_vec(); - // Find the final public IO cycles. - let io_final = io_init - .iter() - .map(|rec| MemFinalRecord { - ram_type: RAMType::Memory, - addr: rec.addr, - value: rec.value, - init_value: rec.value, - cycle: final_access.cycle(rec.addr.into()), - }) - .collect_vec(); + // Legacy public-io memory init is removed. + let io_final: Vec = vec![]; // Find the final hints IO cycles. let hints_final = hints_init @@ -1032,8 +1039,8 @@ pub fn emulate_program<'a>( heap_final.len() as u32, platform.hints.start, hints_final.len() as u32, - io_init.iter().map(|rec| rec.value).collect_vec(), - public_io_words_to_digest_u16_limbs(&io_init.iter().map(|rec| rec.value).collect_vec()), + vec![], + public_io_words_to_digest_words(public_io_digest_input), [0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); @@ -1221,7 +1228,6 @@ pub fn generate_fixed_traces( system_config: &ConstraintSystemConfig, reg_init: &[MemInitRecord], static_mem_init: &[MemInitRecord], - io_addrs: &[Addr], program: &Program, ) -> ZKVMFixedTraces { let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -1240,7 +1246,6 @@ pub fn generate_fixed_traces( &mut zkvm_fixed_traces, reg_init, static_mem_init, - io_addrs, ); system_config .dummy_config @@ -1415,7 +1420,6 @@ pub fn generate_witness<'a, E: ExtensionField>( &pi, &emul_result.final_mem_state.reg, &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, &emul_result.final_mem_state.stack, ) .unwrap(); @@ -1430,7 +1434,6 @@ pub fn generate_witness<'a, E: ExtensionField>( &[], &[], &[], - &[], ) .unwrap(); } @@ -1460,7 +1463,6 @@ pub fn generate_witness<'a, E: ExtensionField>( &pi, &emul_result.final_mem_state.reg, &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, &emul_result.final_mem_state.hints, &emul_result.final_mem_state.stack, &emul_result.final_mem_state.heap, @@ -1532,7 +1534,6 @@ pub struct E2EProgramCtx { pub pubio_len: usize, pub system_config: ConstraintSystemConfig, pub reg_init: Vec, - pub io_init: Vec, pub zkvm_fixed_traces: ZKVMFixedTraces, } @@ -1570,14 +1571,11 @@ pub fn setup_program( }; let system_config = construct_configs::(program_params); let reg_init = system_config.mmu_config.initial_registers(); - let io_init: Vec = vec![]; - // Generate fixed traces let zkvm_fixed_traces = generate_fixed_traces( &system_config, ®_init, &static_addrs, - &io_init.iter().map(|rec| rec.addr).collect_vec(), &program, ); @@ -1589,7 +1587,6 @@ pub fn setup_program( pubio_len, system_config, reg_init, - io_init, zkvm_fixed_traces, } } @@ -1643,9 +1640,8 @@ impl E2EProgramCtx { } /// Setup init mem state - pub fn setup_init_mem(&self, hints: &[u32], public_io: &[u32]) -> InitMemState { - let _ = public_io; - let io_init = self.io_init.clone(); + pub fn setup_init_mem(&self, hints: &[u32], public_io_digest_input: &[u32]) -> InitMemState { + let _ = public_io_digest_input; let hint_init = MemPadder::new_mem_records( self.platform.hints.clone(), hints.len().next_power_of_two(), @@ -1655,7 +1651,7 @@ impl E2EProgramCtx { InitMemState { mem: self.static_addrs.clone(), reg: self.reg_init.clone(), - io: io_init, + io: vec![], hints: hint_init, // stack/heap both init value 0 and range is dynamic stack: vec![], @@ -1687,7 +1683,7 @@ pub fn run_e2e_with_checkpoint< platform: Platform, multi_prover: MultiProver, hints: &[u32], - public_io: &[u32], + public_io_digest_input: &[u32], max_steps: usize, checkpoint: Checkpoint, // for debug purpose @@ -1706,12 +1702,13 @@ pub fn run_e2e_with_checkpoint< let prover = ZKVMProver::new(pk.into(), device); let start = std::time::Instant::now(); - let init_full_mem = prover.setup_init_mem(hints, public_io); + let init_full_mem = prover.setup_init_mem(hints, public_io_digest_input); tracing::debug!("setup_init_mem done in {:?}", start.elapsed()); // Generate witness let is_mock_proving = std::env::var("MOCK_PROVING").is_ok(); if let Checkpoint::PrepE2EProving = checkpoint { + let public_io_digest_input_owned = public_io_digest_input.to_vec(); return E2ECheckpointResult { proofs: None, vk: Some(vk), @@ -1719,6 +1716,7 @@ pub fn run_e2e_with_checkpoint< _ = run_e2e_proof::( &prover, &init_full_mem, + &public_io_digest_input_owned, max_steps, is_mock_proving, target_shard_id, @@ -1736,6 +1734,7 @@ pub fn run_e2e_with_checkpoint< prover.pk.program_ctx.as_ref().unwrap().program.clone(), max_steps, &init_full_mem, + public_io_digest_input, &prover.pk.program_ctx.as_ref().unwrap().platform, &prover.pk.program_ctx.as_ref().unwrap().multi_prover, step_cell_extractor, @@ -1816,6 +1815,7 @@ pub fn run_e2e_proof< >( prover: &ZKVMProver, init_full_mem: &InitMemState, + public_io_digest_input: &[u32], max_steps: usize, is_mock_proving: bool, // for debug purpose @@ -1829,6 +1829,7 @@ pub fn run_e2e_proof< ctx.program.clone(), max_steps, init_full_mem, + public_io_digest_input, &ctx.platform, &ctx.multi_prover, step_cell_extractor, diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 2ac17f528..469bde7ac 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -1,4 +1,5 @@ use crate::uint::UIntLimbs; +use crate::scheme::constants::SEPTIC_EXTENSION_DEGREE; pub use ceno_emul::PC_STEP_SIZE; pub const ECALL_HALT_OPCODE: [usize; 2] = [0x00_00, 0x00_00]; @@ -18,6 +19,8 @@ pub const HINT_START_ADDR_IDX: usize = HEAP_LENGTH_IDX + 1; pub const HINT_LENGTH_IDX: usize = HINT_START_ADDR_IDX + 1; pub const SHARD_RW_SUM_IDX: usize = HINT_LENGTH_IDX + 1; +pub const PUBIO_DIGEST_IDX: usize = SHARD_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE * 2; +pub const PUBIO_DIGEST_U16_LIMBS: usize = 8 * UINT_LIMBS; /// vector-based public value, id start from 0 pub const PUBLIC_IO_IDX: usize = 0; diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index 84a836bf3..f286a8173 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -3,6 +3,7 @@ mod fptower_fp2_add; mod fptower_fp2_mul; mod halt; mod keccak; +mod pubio_commit; mod sha_extend; mod uint256; mod weierstrass_add; @@ -13,6 +14,7 @@ pub use fptower_fp::{FpAddInstruction, FpMulInstruction}; pub use fptower_fp2_add::Fp2AddInstruction; pub use fptower_fp2_mul::Fp2MulInstruction; pub use keccak::KeccakInstruction; +pub use pubio_commit::PubIoCommitInstruction; pub use sha_extend::ShaExtendInstruction; pub use uint256::{Secp256k1InvInstruction, Secp256r1InvInstruction, Uint256MulInstruction}; pub use weierstrass_add::WeierstrassAddAssignInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs b/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs new file mode 100644 index 000000000..f3e6509ee --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs @@ -0,0 +1,163 @@ +use std::marker::PhantomData; + +use ceno_emul::{Change, InsnKind, Platform, PubIoCommitSpec, StepRecord, SyscallSpec, WORD_SIZE, WriteOp}; +use ff_ext::ExtensionField; +use multilinear_extensions::ToExpr; +use p3::field::FieldAlgebra; + +use crate::{ + chip_handler::general::InstFetch, + circuit_builder::CircuitBuilder, + e2e::ShardContext, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + constants::{LIMB_BITS, LIMB_MASK, MEM_BITS, UInt}, + ecall_base::OpFixedRS, + insn_base::{MemAddr, ReadMEM, StateInOut}, + }, + }, + precompiles::{PUBIO_COMMIT_WORDS, PubioCommitLayout}, + structs::ProgramParams, + tables::InsnRecord, + witness::LkMultiplicity, +}; + +#[derive(Debug)] +pub struct EcallPubioCommitConfig { + vm_state: StateInOut, + ecall_id: OpFixedRS, + digest_ptr: (OpFixedRS, MemAddr), + mem_read: [ReadMEM; PUBIO_COMMIT_WORDS], +} + +pub struct PubIoCommitInstruction(PhantomData); + +impl Instruction for PubIoCommitInstruction { + type InstructionConfig = EcallPubioCommitConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } + + fn name() -> String { + "Ecall_PubioCommit".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _param: &ProgramParams, + ) -> Result { + let vm_state = StateInOut::construct_circuit(cb, false)?; + let syscall_code = PubIoCommitSpec::CODE; + + let ecall_id = OpFixedRS::<_, { Platform::reg_ecall() }, false>::construct_circuit( + cb, + UInt::from_const_unchecked(vec![ + syscall_code & LIMB_MASK, + (syscall_code >> LIMB_BITS) & LIMB_MASK, + ]) + .register_expr(), + vm_state.ts, + )?; + + let digest_ptr_value = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; + let digest_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + cb, + digest_ptr_value.uint_unaligned().register_expr(), + vm_state.ts, + )?; + + cb.lk_fetch(&InsnRecord::new( + vm_state.pc.expr(), + InsnKind::ECALL.into(), + None, + 0.into(), + 0.into(), + 0.into(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), + ))?; + + let layout = PubioCommitLayout::construct_circuit(cb)?; + let mem_read: [ReadMEM; PUBIO_COMMIT_WORDS] = (0..PUBIO_COMMIT_WORDS) + .map(|i| { + ReadMEM::construct_circuit( + cb, + digest_ptr.prev_value.as_ref().unwrap().value() + + E::BaseField::from_canonical_u32((i * WORD_SIZE) as u32).expr(), + layout.digest_words[i].clone(), + vm_state.ts, + ) + }) + .collect::, _>>()? + .try_into() + .expect("pubio read width is fixed"); + + Ok(EcallPubioCommitConfig { + vm_state, + ecall_id, + digest_ptr: (digest_ptr, digest_ptr_value), + mem_read, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let syscall_code = PubIoCommitSpec::CODE; + let ops = step.syscall().expect("syscall step"); + assert_eq!(ops.reg_ops.len(), 1, "PUB_IO_COMMIT expects 1 reg op"); + assert_eq!( + ops.mem_ops.len(), + PUBIO_COMMIT_WORDS, + "PUB_IO_COMMIT expects {} mem ops", + PUBIO_COMMIT_WORDS + ); + + config.vm_state.assign_instance(instance, shard_ctx, step)?; + + config.ecall_id.assign_op( + instance, + shard_ctx, + lk_multiplicity, + step.cycle(), + &WriteOp::new_register_op( + Platform::reg_ecall(), + Change::new(syscall_code, syscall_code), + step.rs1().unwrap().previous_cycle, + ), + )?; + + config + .digest_ptr + .1 + .assign_instance(instance, lk_multiplicity, ops.reg_ops[0].value.after)?; + config.digest_ptr.0.assign_op( + instance, + shard_ctx, + lk_multiplicity, + step.cycle(), + &ops.reg_ops[0], + )?; + + for (reader, op) in config.mem_read.iter().zip(&ops.mem_ops) { + reader.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; + } + + lk_multiplicity.fetch(step.pc().before.0); + Ok(()) + } +} + + + + + + diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 1a378ad8c..d1c0287d4 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -335,9 +335,20 @@ impl ReadMEM { step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.memory_op().unwrap(); + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) + } + + pub fn assign_op( + &self, + instance: &mut [E::BaseField], + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + cycle: Cycle, + op: &WriteOp, + ) -> Result<(), ZKVMError> { let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); - let shard_cycle = step.cycle() - current_shard_offset_cycle; + let shard_cycle = cycle - current_shard_offset_cycle; // Memory state set_val!(instance, self.prev_ts, shard_prev_cycle); @@ -353,7 +364,7 @@ impl ReadMEM { RAMType::Memory, op.addr, op.addr.baddr().0 as u64, - step.cycle() + Tracer::SUBCYCLE_MEM, + cycle + Tracer::SUBCYCLE_MEM, op.previous_cycle, op.value.after, None, diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 09ea23733..77c39af98 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -21,9 +21,10 @@ use crate::{ div::{DivInstruction, DivuInstruction, RemInstruction, RemuInstruction}, ecall::{ Fp2AddInstruction, Fp2MulInstruction, FpAddInstruction, FpMulInstruction, - KeccakInstruction, Secp256k1InvInstruction, Secp256r1InvInstruction, - ShaExtendInstruction, Uint256MulInstruction, WeierstrassAddAssignInstruction, - WeierstrassDecompressInstruction, WeierstrassDoubleAssignInstruction, + KeccakInstruction, PubIoCommitInstruction, Secp256k1InvInstruction, + Secp256r1InvInstruction, ShaExtendInstruction, Uint256MulInstruction, + WeierstrassAddAssignInstruction, WeierstrassDecompressInstruction, + WeierstrassDoubleAssignInstruction, }, logic::{AndInstruction, OrInstruction, XorInstruction}, logic_imm::{AndiInstruction, OriInstruction, XoriInstruction}, @@ -73,7 +74,7 @@ use strum::{EnumCount, IntoEnumIterator}; pub mod mmu; const ECALL_HALT: u32 = Platform::ecall_halt(); -const ECALL_PUB_IO_COMMIT: u32 = Platform::ecall_pub_io_commit(); +const ECALL_PUB_IO_COMMIT: u32 = PubIoCommitSpec::CODE; pub struct Rv32imConfig { // ALU Opcodes. @@ -135,8 +136,7 @@ pub struct Rv32imConfig { // Ecall Opcodes pub halt_config: as Instruction>::InstructionConfig, - pub pubio_commit_config: - as Instruction>::InstructionConfig, + pub pubio_commit_config: as Instruction>::InstructionConfig, pub keccak_config: as Instruction>::InstructionConfig, pub sha_extend_config: as Instruction>::InstructionConfig, pub bn254_add_config: @@ -359,7 +359,7 @@ impl Rv32imConfig { } let halt_config = register_ecall_circuit!(HaltInstruction, ecall_cells_map); let pubio_commit_config = - register_ecall_circuit!(LargeEcallDummy, ecall_cells_map); + register_ecall_circuit!(PubIoCommitInstruction, ecall_cells_map); // Keccak precompile is a known hotspot for peak memory. // Its heavy read/write/LK activity inflates tower-witness usage, causing @@ -568,10 +568,7 @@ impl Rv32imConfig { // system fixed.register_opcode_circuit::>(cs, &self.halt_config); - fixed.register_opcode_circuit::>( - cs, - &self.pubio_commit_config, - ); + fixed.register_opcode_circuit::>(cs, &self.pubio_commit_config); fixed.register_opcode_circuit::>(cs, &self.keccak_config); fixed.register_opcode_circuit::>(cs, &self.sha_extend_config); fixed.register_opcode_circuit::>>( @@ -773,7 +770,7 @@ impl Rv32imConfig { // ecall / halt assign_ecall!(HaltInstruction, halt_config, ECALL_HALT); assign_ecall!( - LargeEcallDummy, + PubIoCommitInstruction, pubio_commit_config, ECALL_PUB_IO_COMMIT ); @@ -1060,7 +1057,7 @@ impl Rv32imConfig { .expect("unable to find name"), ECALL_PUB_IO_COMMIT => *self .ecall_cells_map - .get(&LargeEcallDummy::::name()) + .get(&PubIoCommitInstruction::::name()) .expect("unable to find name"), KeccakSpec::CODE => *self .ecall_cells_map diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index b3f96feb0..5b264dad9 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -5,9 +5,9 @@ use crate::{ structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsInitCircuit, HintsTable, - LocalFinalCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOInitCircuit, - PubIOTable, RegTable, RegTableInitCircuit, ShardRamCircuit, StackInitCircuit, StackTable, - StaticMemInitCircuit, StaticMemTable, TableCircuit, + LocalFinalCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, RegTable, + RegTableInitCircuit, ShardRamCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, + StaticMemTable, TableCircuit, }, }; use ceno_emul::{Addr, IterAddresses, WORD_SIZE, Word}; @@ -20,8 +20,6 @@ pub struct MmuConfig { pub reg_init_config: as TableCircuit>::TableConfig, /// Initialization of memory with static addresses. pub static_mem_init_config: as TableCircuit>::TableConfig, - /// Initialization of public IO. - pub public_io_init_config: as TableCircuit>::TableConfig, /// Initialization of hints. pub hints_init_config: as TableCircuit>::TableConfig, /// Initialization of heap. @@ -41,8 +39,6 @@ impl MmuConfig { let static_mem_init_config = cs.register_table_circuit::>(); - let public_io_init_config = cs.register_table_circuit::>(); - let hints_init_config = cs.register_table_circuit::>(); let stack_init_config = cs.register_table_circuit::>(); let heap_init_config = cs.register_table_circuit::>(); @@ -52,7 +48,6 @@ impl MmuConfig { Self { reg_init_config, static_mem_init_config, - public_io_init_config, hints_init_config, stack_init_config, heap_init_config, @@ -68,12 +63,10 @@ impl MmuConfig { fixed: &mut ZKVMFixedTraces, reg_init: &[MemInitRecord], static_mem_init: &[MemInitRecord], - io_addrs: &[Addr], ) { assert!( chain!( static_mem_init.iter_addresses(), - io_addrs.iter_addresses(), // TODO: optimize with min_max and Range. self.params.platform.hints.iter_addresses(), ) @@ -88,12 +81,6 @@ impl MmuConfig { &self.static_mem_init_config, static_mem_init, ); - - fixed.register_table_circuit::>( - cs, - &self.public_io_init_config, - io_addrs, - ); fixed.register_table_circuit::>(cs, &self.hints_init_config, &()); fixed.register_table_circuit::>(cs, &self.stack_init_config, &()); fixed.register_table_circuit::>(cs, &self.heap_init_config, &()); @@ -130,7 +117,6 @@ impl MmuConfig { pv: &PublicValues, reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], - io_final: &[MemFinalRecord], stack_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { witness.assign_table_circuit::>( @@ -145,12 +131,6 @@ impl MmuConfig { static_mem_final, )?; - witness.assign_table_circuit::>( - cs, - &self.public_io_init_config, - io_final, - )?; - witness.assign_table_circuit::>( cs, &self.stack_init_config, @@ -168,13 +148,11 @@ impl MmuConfig { pv: &PublicValues, reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], - io_final: &[MemFinalRecord], hints_final: &[MemFinalRecord], stack_final: &[MemFinalRecord], heap_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { let all_records = vec![ - (PubIOTable::name(), None, io_final), (RegTable::name(), None, reg_final), (StaticMemTable::name(), None, static_mem_final), (StackTable::name(), None, stack_final), @@ -224,10 +202,6 @@ impl MmuConfig { pub fn static_mem_len(&self) -> usize { ::len(&self.params) } - - pub fn public_io_len(&self) -> usize { - ::len(&self.params) - } } pub struct MemPadder { diff --git a/ceno_zkvm/src/precompiles/mod.rs b/ceno_zkvm/src/precompiles/mod.rs index 3d9a6e545..f282efddb 100644 --- a/ceno_zkvm/src/precompiles/mod.rs +++ b/ceno_zkvm/src/precompiles/mod.rs @@ -1,6 +1,7 @@ mod bitwise_keccakf; mod fptower; mod lookup_keccakf; +mod pubio_commit; mod sha256; mod uint256; mod utils; @@ -12,6 +13,7 @@ pub use lookup_keccakf::{ ROUNDS as KECCAK_ROUNDS, ROUNDS_CEIL_LOG2 as KECCAK_ROUNDS_CEIL_LOG2, XOR_LOOKUPS, run_lookup_keccakf, setup_gkr_circuit as setup_lookup_keccak_gkr_circuit, }; +pub use pubio_commit::{PUBIO_COMMIT_WORDS, PUBIO_DIGEST_U16_LIMBS, PubioCommitLayout}; pub use bitwise_keccakf::{ KeccakLayout as BitwiseKeccakLayout, run_keccakf as run_bitwise_keccakf, diff --git a/ceno_zkvm/src/precompiles/pubio_commit.rs b/ceno_zkvm/src/precompiles/pubio_commit.rs new file mode 100644 index 000000000..d4d804fc6 --- /dev/null +++ b/ceno_zkvm/src/precompiles/pubio_commit.rs @@ -0,0 +1,39 @@ +use ff_ext::ExtensionField; +use multilinear_extensions::{Instance, ToExpr}; + +use crate::{ + chip_handler::{MemoryExpr, general::PublicValuesQuery}, + circuit_builder::CircuitBuilder, + error::ZKVMError, + instructions::riscv::constants::UINT_LIMBS, +}; + +pub const PUBIO_COMMIT_WORDS: usize = 8; +pub const PUBIO_DIGEST_U16_LIMBS: usize = PUBIO_COMMIT_WORDS * UINT_LIMBS; + +#[derive(Debug)] +pub struct PubioCommitLayout { + /// Public digest instances laid out as 16-bit limbs, little-endian per word. + pub digest_u16_limbs: [Instance; PUBIO_DIGEST_U16_LIMBS], + pub digest_words: [MemoryExpr; PUBIO_COMMIT_WORDS], +} + +impl PubioCommitLayout { + pub fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let digest_u16_limbs = cb.query_public_io_digest()?; + let digest_words = core::array::from_fn(|word_idx| { + let limb_base = word_idx * UINT_LIMBS; + [ + digest_u16_limbs[limb_base].expr(), + digest_u16_limbs[limb_base + 1].expr(), + ] + }); + + Ok(Self { + digest_u16_limbs, + digest_words, + }) + } +} + + diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 50475fe0f..b16c81675 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -22,7 +22,8 @@ use crate::{ constants::{ END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, LIMB_BITS, - LIMB_MASK, SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, + LIMB_MASK, PUBIO_DIGEST_IDX, PUBIO_DIGEST_U16_LIMBS, SHARD_ID_IDX, + SHARD_RW_SUM_IDX, UINT_LIMBS, }, ecall::HaltInstruction, }, @@ -87,7 +88,7 @@ pub struct PublicValues { pub hint_start_addr: u32, pub hint_shard_len: u32, pub public_io: Vec, - pub pubio_digest: [u32; 16], + pub public_io_digest: [u32; 8], pub shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], } @@ -105,7 +106,7 @@ impl PublicValues { hint_start_addr: u32, hint_shard_len: u32, public_io: Vec, - pubio_digest: [u32; 16], + public_io_digest: [u32; 8], shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], ) -> Self { Self { @@ -120,7 +121,7 @@ impl PublicValues { hint_start_addr, hint_shard_len, public_io, - pubio_digest, + public_io_digest, shard_rw_sum, } } @@ -144,6 +145,16 @@ impl PublicValues { { E::BaseField::from_canonical_u32(self.shard_rw_sum[idx - SHARD_RW_SUM_IDX]) } + idx if (PUBIO_DIGEST_IDX..(PUBIO_DIGEST_IDX + PUBIO_DIGEST_U16_LIMBS)) + .contains(&idx) => + { + let digest_limb_idx = idx - PUBIO_DIGEST_IDX; + let word_idx = digest_limb_idx / UINT_LIMBS; + let limb_idx = digest_limb_idx % UINT_LIMBS; + E::BaseField::from_canonical_u32( + (self.public_io_digest[word_idx] >> (limb_idx * LIMB_BITS)) & LIMB_MASK, + ) + } _ => panic!("public value index {index} out of range"), } } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 96b28ce1a..123af4dda 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -109,11 +109,15 @@ impl< } } - pub fn setup_init_mem(&self, hints: &[u32], public_io: &[u32]) -> crate::e2e::InitMemState { + pub fn setup_init_mem( + &self, + hints: &[u32], + public_io_digest_input: &[u32], + ) -> crate::e2e::InitMemState { let Some(ctx) = self.pk.program_ctx.as_ref() else { panic!("empty program ctx") }; - ctx.setup_init_mem(hints, public_io) + ctx.setup_init_mem(hints, public_io_digest_input) } } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index d8e1c9029..1bed1023a 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -396,7 +396,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], [0; 16], [0; 14]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], [0; 8], [0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(&shard_ctx, zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 130adc787..505178944 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -2,7 +2,7 @@ use ceno_emul::{Addr, VM_REG_COUNT, WORD_SIZE}; use ff_ext::ExtensionField; use gkr_iop::error::CircuitBuilderError; use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; -use ram_circuit::{DynVolatileRamCircuit, NonVolatileRamCircuit, PubIORamInitCircuit}; +use ram_circuit::{DynVolatileRamCircuit, NonVolatileRamCircuit}; use crate::{ instructions::riscv::constants::UINT_LIMBS, @@ -241,22 +241,4 @@ impl NonVolatileTable for StaticMemTable { pub type StaticMemInitCircuit = NonVolatileRamCircuit>; -#[derive(Clone)] -pub struct PubIOTable; - -impl NonVolatileTable for PubIOTable { - const RAM_TYPE: RAMType = RAMType::Memory; - const V_LIMBS: usize = UINT_LIMBS; - const WRITABLE: bool = false; - - fn name() -> &'static str { - "PubIOTable" - } - - fn len(params: &ProgramParams) -> usize { - params.pubio_len - } -} - -pub type PubIOInitCircuit = PubIORamInitCircuit; pub type LocalFinalCircuit = LocalFinalRamCircuit; diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 27df29e90..e98a827f2 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -775,7 +775,7 @@ mod tests { 0, 0, vec![0], // dummy - [0; 16], + [0; 8], shard_rw_sum, ); From a900a587052818d8d20c59ab7194d1d87a60284e Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 27 Mar 2026 16:11:37 +0800 Subject: [PATCH 3/6] Fix PUB_IO commit witness alignment and simplify public-io handling --- ceno_cli/src/commands/common_args/ceno.rs | 26 +----- ceno_emul/src/lib.rs | 3 +- ceno_emul/src/platform.rs | 9 +- ceno_emul/src/syscalls/pubio_commit.rs | 3 - ceno_rt/src/mmio.rs | 2 +- ceno_zkvm/src/bin/e2e.rs | 22 +---- ceno_zkvm/src/chip_handler/general.rs | 5 +- ceno_zkvm/src/e2e.rs | 57 +++++++----- ceno_zkvm/src/instructions/riscv/constants.rs | 3 +- .../instructions/riscv/ecall/pubio_commit.rs | 34 ++++--- ceno_zkvm/src/instructions/riscv/rv32im.rs | 8 +- ceno_zkvm/src/precompiles/pubio_commit.rs | 2 - ceno_zkvm/src/tables/ram/ram_circuit.rs | 56 +----------- ceno_zkvm/src/tables/ram/ram_impl.rs | 91 ------------------- examples/examples/sha256.rs | 6 -- 15 files changed, 67 insertions(+), 260 deletions(-) diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index 75be1322b..15e736507 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -169,25 +169,9 @@ impl CenoOptions { self.heap_size.next_multiple_of(WORD_SIZE as u32) } - /// Read the public io into ceno stdin - pub fn read_public_io(&self) -> anyhow::Result> { - if let Some(public_io) = &self.public_io { - // if vector contains only one element, write it as a raw `u32` - // otherwise, write the entire vector - // in both cases, convert the resulting `CenoStdin` into a `Vec` - if public_io.len() == 1 { - CenoStdin::default() - .write(&public_io[0]) - .map(|stdin| Into::>::into(&*stdin)) - } else { - CenoStdin::default() - .write(public_io) - .map(|stdin| Into::>::into(&*stdin)) - } - .context("failed to get public_io".to_string()) - } else { - Ok(vec![]) - } + /// Read raw public-io words; digesting happens later in the zkVM pipeline. + pub fn read_public_io(&self) -> Vec { + self.public_io.clone().unwrap_or_default() } /// Read the hints @@ -367,9 +351,7 @@ fn run_elf_inner< options.max_cycle_per_shard, ); - let public_io_digest_input = options - .read_public_io() - .context("failed to read public io")?; + let public_io_digest_input = options.read_public_io(); let public_io_digest = public_io_words_to_digest_words(&public_io_digest_input); tracing::debug!("public io digest words: {:?}", public_io_digest); let public_io_size = options.public_io_size; diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 10ebcf587..8bf8eacb1 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -31,7 +31,7 @@ pub mod disassemble; mod syscalls; pub use syscalls::{ BLS12381_ADD, BLS12381_DECOMPRESS, BLS12381_DOUBLE, BN254_ADD, BN254_DOUBLE, BN254_FP_ADD, - BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, SECP256K1_ADD, + BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, PubIoCommitSpec, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, SyscallSpec, UINT256_MUL, @@ -41,7 +41,6 @@ pub use syscalls::{ }, keccak_permute::{KECCAK_WORDS, KeccakSpec}, phantom::LogPcCycleSpec, - PubIoCommitSpec, secp256k1::{ COORDINATE_WORDS as SECP256K1_COORDINATE_WORDS, SECP256K1_ARG_WORDS, Secp256k1AddSpec, Secp256k1DecompressSpec, Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index bfcce5cff..16e7eddb0 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -91,7 +91,7 @@ impl Display for Platform { // │ // └───────────────────────────── 0x8000_0000 (rom base) pub static CENO_PLATFORM: Lazy = Lazy::new(|| Platform { - rom: 0x0800_0000..0x1000_0000, // 128 MB + rom: 0x0800_0000..0x1000_0000, // 128 MB stack: 0x1000_0000..0x2000_4000, // stack grows downward, 0x4000 reserved for debug io. // we make hints start from 0x2800_0000 thus reserve a 128MB gap for debug io // at the end of stack @@ -180,12 +180,7 @@ impl Platform { /// Validate the platform configuration, range shall not overlap. pub fn validate(&self) -> bool { - let mut ranges = [ - &self.rom, - &self.stack, - &self.heap, - &self.hints, - ]; + let mut ranges = [&self.rom, &self.stack, &self.heap, &self.hints]; ranges.sort_by_key(|r| r.start); for i in 0..ranges.len() - 1 { if ranges[i].end > ranges[i + 1].start { diff --git a/ceno_emul/src/syscalls/pubio_commit.rs b/ceno_emul/src/syscalls/pubio_commit.rs index e182383ee..aa49281d8 100644 --- a/ceno_emul/src/syscalls/pubio_commit.rs +++ b/ceno_emul/src/syscalls/pubio_commit.rs @@ -23,6 +23,3 @@ pub fn pubio_commit(vm: &VMState) -> SyscallEffects { next_pc: None, } } - - - diff --git a/ceno_rt/src/mmio.rs b/ceno_rt/src/mmio.rs index 62c4a847b..cde0f42a7 100644 --- a/ceno_rt/src/mmio.rs +++ b/ceno_rt/src/mmio.rs @@ -1,8 +1,8 @@ //! Memory-mapped I/O (MMIO) functions. use ceno_serde::from_slice; -use core::{cell::UnsafeCell, ptr, slice::from_raw_parts}; use ceno_syscall::syscall_pub_io_commit; +use core::{cell::UnsafeCell, ptr, slice::from_raw_parts}; use serde::de::DeserializeOwned; use tiny_keccak::{Hasher, Keccak}; diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 396321724..67af854e6 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -174,26 +174,8 @@ fn main() { .with(args.profiling.is_none().then_some(default_filter)) .init(); - // process public input first - let public_io = args - .public_io - .and_then(|public_io| { - // if the vector contains only one element, write it as a raw `u32` - // otherwise, write the entire vector - // in both cases, convert the resulting `CenoStdin` into a `Vec` - if public_io.len() == 1 { - CenoStdin::default() - .write(&public_io[0]) - .ok() - .map(|stdin| Into::>::into(&*stdin)) - } else { - CenoStdin::default() - .write(&public_io) - .ok() - .map(|stdin| Into::>::into(&*stdin)) - } - }) - .unwrap_or_default(); + // process public input first; this is raw u32 public input, not pre-digested words. + let public_io = args.public_io.unwrap_or_default(); assert!( public_io.len() <= args.public_io_size as usize / WORD_SIZE, "require pub io length {} < max public_io_size {}", diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 8606baabd..3d629b0aa 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -5,8 +5,8 @@ use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, - HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, - PUBIO_DIGEST_IDX, PUBIO_DIGEST_U16_LIMBS, SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, + HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBIO_DIGEST_IDX, + PUBIO_DIGEST_U16_LIMBS, PUBLIC_IO_IDX, SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, }, scheme::constants::SEPTIC_EXTENSION_DEGREE, tables::InsnRecord, @@ -24,6 +24,7 @@ pub trait PublicValuesQuery { fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError>; + #[allow(dead_code)] fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; fn query_public_io_digest( &mut self, diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 91f7106db..c2b74d730 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -34,14 +34,13 @@ use ff_ext::{ExtensionField, SmallField}; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; use gkr_iop::{RAMType, hal::ProverBackend}; +use itertools::Itertools; #[cfg(debug_assertions)] -use itertools::MinMaxResult; -use itertools::{Itertools, chain}; +use itertools::{MinMaxResult, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; use rustc_hash::FxHashSet; use serde::Serialize; -use tiny_keccak::{Hasher, Keccak}; #[cfg(debug_assertions)] use std::collections::{HashMap, HashSet}; use std::{ @@ -50,6 +49,7 @@ use std::{ ops::Range, sync::Arc, }; +use tiny_keccak::{Hasher, Keccak}; use tracing::info_span; use transcript::BasicTranscript as Transcript; use witness::next_pow2_instance_padding; @@ -66,22 +66,9 @@ pub fn public_io_words_to_digest_words(words: &[u32]) -> [u32; 8] { } let mut digest = [0u8; 32]; keccak.finalize(&mut digest); - #[cfg(target_endian = "little")] - { - // Reinterpret Keccak digest bytes as 8 little-endian u32 words. - unsafe { core::mem::transmute::<[u8; 32], [u32; 8]>(digest) } - } - #[cfg(not(target_endian = "little"))] - { - core::array::from_fn(|i| { - u32::from_le_bytes([ - digest[i * 4], - digest[i * 4 + 1], - digest[i * 4 + 2], - digest[i * 4 + 3], - ]) - }) - } + + // Reinterpret Keccak digest bytes as 8 little-endian u32 words. + unsafe { core::mem::transmute::<[u8; 32], [u32; 8]>(digest) } } // define a relative small number to make first shard handle much less instruction @@ -1572,12 +1559,8 @@ pub fn setup_program( let system_config = construct_configs::(program_params); let reg_init = system_config.mmu_config.initial_registers(); // Generate fixed traces - let zkvm_fixed_traces = generate_fixed_traces( - &system_config, - ®_init, - &static_addrs, - &program, - ); + let zkvm_fixed_traces = + generate_fixed_traces(&system_config, ®_init, &static_addrs, &program); E2EProgramCtx { program: Arc::new(program), @@ -2135,6 +2118,7 @@ mod tests { use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord}; use itertools::Itertools; use std::sync::Arc; + use tiny_keccak::{Hasher, Keccak}; #[test] fn test_single_prover_shard_ctx() { @@ -2264,4 +2248,27 @@ mod tests { } } } + + #[test] + fn public_io_digest_matches_guest_commit_hashing() { + let words = vec![4191u32]; + let digest_words = super::public_io_words_to_digest_words(&words); + + let mut keccak = Keccak::v256(); + for word in &words { + keccak.update(&word.to_le_bytes()); + } + let mut digest = [0u8; 32]; + keccak.finalize(&mut digest); + let expected = core::array::from_fn(|i| { + u32::from_le_bytes([ + digest[i * 4], + digest[i * 4 + 1], + digest[i * 4 + 2], + digest[i * 4 + 3], + ]) + }); + + assert_eq!(digest_words, expected); + } } diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 469bde7ac..09feb6280 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -1,5 +1,4 @@ -use crate::uint::UIntLimbs; -use crate::scheme::constants::SEPTIC_EXTENSION_DEGREE; +use crate::{scheme::constants::SEPTIC_EXTENSION_DEGREE, uint::UIntLimbs}; pub use ceno_emul::PC_STEP_SIZE; pub const ECALL_HALT_OPCODE: [usize; 2] = [0x00_00, 0x00_00]; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs b/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs index f3e6509ee..8ca50708b 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs @@ -1,6 +1,8 @@ use std::marker::PhantomData; -use ceno_emul::{Change, InsnKind, Platform, PubIoCommitSpec, StepRecord, SyscallSpec, WORD_SIZE, WriteOp}; +use ceno_emul::{ + Change, InsnKind, Platform, PubIoCommitSpec, StepRecord, SyscallSpec, WORD_SIZE, WriteOp, +}; use ff_ext::ExtensionField; use multilinear_extensions::ToExpr; use p3::field::FieldAlgebra; @@ -15,7 +17,7 @@ use crate::{ riscv::{ constants::{LIMB_BITS, LIMB_MASK, MEM_BITS, UInt}, ecall_base::OpFixedRS, - insn_base::{MemAddr, ReadMEM, StateInOut}, + insn_base::{MemAddr, StateInOut, WriteMEM}, }, }, precompiles::{PUBIO_COMMIT_WORDS, PubioCommitLayout}, @@ -29,7 +31,7 @@ pub struct EcallPubioCommitConfig { vm_state: StateInOut, ecall_id: OpFixedRS, digest_ptr: (OpFixedRS, MemAddr), - mem_read: [ReadMEM; PUBIO_COMMIT_WORDS], + mem_rw: [WriteMEM; PUBIO_COMMIT_WORDS], } pub struct PubIoCommitInstruction(PhantomData); @@ -82,13 +84,14 @@ impl Instruction for PubIoCommitInstruction { ))?; let layout = PubioCommitLayout::construct_circuit(cb)?; - let mem_read: [ReadMEM; PUBIO_COMMIT_WORDS] = (0..PUBIO_COMMIT_WORDS) + let mem_rw: [WriteMEM; PUBIO_COMMIT_WORDS] = (0..PUBIO_COMMIT_WORDS) .map(|i| { - ReadMEM::construct_circuit( + WriteMEM::construct_circuit( cb, digest_ptr.prev_value.as_ref().unwrap().value() + E::BaseField::from_canonical_u32((i * WORD_SIZE) as u32).expr(), layout.digest_words[i].clone(), + layout.digest_words[i].clone(), vm_state.ts, ) }) @@ -100,7 +103,7 @@ impl Instruction for PubIoCommitInstruction { vm_state, ecall_id, digest_ptr: (digest_ptr, digest_ptr_value), - mem_read, + mem_rw, }) } @@ -135,10 +138,11 @@ impl Instruction for PubIoCommitInstruction { ), )?; - config - .digest_ptr - .1 - .assign_instance(instance, lk_multiplicity, ops.reg_ops[0].value.after)?; + config.digest_ptr.1.assign_instance( + instance, + lk_multiplicity, + ops.reg_ops[0].value.after, + )?; config.digest_ptr.0.assign_op( instance, shard_ctx, @@ -147,17 +151,11 @@ impl Instruction for PubIoCommitInstruction { &ops.reg_ops[0], )?; - for (reader, op) in config.mem_read.iter().zip(&ops.mem_ops) { - reader.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; + for (writer, op) in config.mem_rw.iter().zip(&ops.mem_ops) { + writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; } lk_multiplicity.fetch(step.pc().before.0); Ok(()) } } - - - - - - diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 77c39af98..90d53fc21 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -46,10 +46,10 @@ use ceno_emul::{ Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, InsnKind::{self, *}, - KeccakSpec, LogPcCycleSpec, Platform, PubIoCommitSpec, Secp256k1AddSpec, Secp256k1DecompressSpec, - Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, Secp256r1AddSpec, Secp256r1DoubleSpec, - Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepCellExtractor, StepIndex, StepRecord, - SyscallSpec, Uint256MulSpec, Word, + KeccakSpec, LogPcCycleSpec, Platform, PubIoCommitSpec, Secp256k1AddSpec, + Secp256k1DecompressSpec, Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, Secp256r1AddSpec, + Secp256r1DoubleSpec, Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepCellExtractor, StepIndex, + StepRecord, SyscallSpec, Uint256MulSpec, Word, }; use dummy::LargeEcallDummy; use ff_ext::ExtensionField; diff --git a/ceno_zkvm/src/precompiles/pubio_commit.rs b/ceno_zkvm/src/precompiles/pubio_commit.rs index d4d804fc6..35cf93eb4 100644 --- a/ceno_zkvm/src/precompiles/pubio_commit.rs +++ b/ceno_zkvm/src/precompiles/pubio_commit.rs @@ -35,5 +35,3 @@ impl PubioCommitLayout { }) } } - - diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 249f70125..0afa819be 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,6 +1,4 @@ -use super::ram_impl::{ - LocalFinalRAMTableConfig, NonVolatileTableConfigTrait, PubIOTableInitConfig, -}; +use super::ram_impl::{LocalFinalRAMTableConfig, NonVolatileTableConfigTrait}; use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, @@ -123,58 +121,6 @@ impl< } } -/// PubIORamCircuit initializes and finalizes memory -/// - at fixed addresses, -/// - with content from the public input of proofs. -/// -/// This circuit does not and cannot decide whether the memory is mutable or not. -/// It supports LOAD where the program reads the public input, -/// or STORE where the memory content must equal the public input after execution. -pub struct PubIORamInitCircuit(PhantomData<(E, R)>); - -impl TableCircuit - for PubIORamInitCircuit -{ - type TableConfig = PubIOTableInitConfig; - type FixedInput = [Addr]; - type WitnessInput<'a> = [MemFinalRecord]; - - fn name() -> String { - format!("RAM_{:?}_{}", NVRAM::RAM_TYPE, NVRAM::name()) - } - - fn construct_circuit( - cb: &mut CircuitBuilder, - params: &ProgramParams, - ) -> Result { - cb.set_omc_init_only(); - Ok(cb.namespace( - || Self::name(), - |cb| Self::TableConfig::construct_circuit(cb, params), - )?) - } - - fn generate_fixed_traces( - config: &Self::TableConfig, - num_fixed: usize, - io_addrs: &[Addr], - ) -> RowMajorMatrix { - // assume returned table is well-formed including padding - config.gen_init_state(num_fixed, io_addrs) - } - - fn assign_instances( - config: &Self::TableConfig, - num_witin: usize, - num_structural_witin: usize, - _multiplicity: &[HashMap], - final_mem: &[MemFinalRecord], - ) -> Result, ZKVMError> { - // assume returned table is well-formed including padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_mem)?) - } -} - /// - **Dynamic**: The address space is bounded within a specific range, /// though the range itself may be dynamically determined per proof. /// - **Volatile**: The initial values are set to `0` diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 7e5f1293d..78049ecb3 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -14,7 +14,6 @@ use super::{ ram_circuit::{DynVolatileRamTable, MemFinalRecord, NonVolatileTable}, }; use crate::{ - chip_handler::general::PublicValuesQuery, circuit_builder::{CircuitBuilder, SetTableSpec}, e2e::ShardContext, instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, @@ -161,96 +160,6 @@ impl NonVolatileTableConfigTrait< } } -/// define public io -/// init value set by instance -#[derive(Clone, Debug)] -pub struct PubIOTableInitConfig { - addr: Fixed, - phantom: PhantomData, - params: ProgramParams, -} - -impl PubIOTableInitConfig { - pub fn construct_circuit( - cb: &mut CircuitBuilder, - params: &ProgramParams, - ) -> Result { - assert!(!NVRAM::WRITABLE); - let init_v = cb.query_public_io()?; - let addr = cb.create_fixed(|| "addr"); - - let init_table = [ - vec![(NVRAM::RAM_TYPE as usize).into()], - vec![Expression::Fixed(addr)], - init_v.iter().map(|v| v.expr_as_instance()).collect_vec(), - vec![Expression::ZERO], // Initial cycle. - ] - .concat(); - - cb.w_table_record( - || "init_table", - NVRAM::RAM_TYPE, - SetTableSpec { - len: Some(NVRAM::len(params)), - structural_witins: vec![], - }, - init_table, - )?; - - Ok(Self { - addr, - phantom: PhantomData, - params: params.clone(), - }) - } - - /// assign to fixed address - pub fn gen_init_state( - &self, - num_fixed: usize, - io_addrs: &[Addr], - ) -> RowMajorMatrix { - assert!(NVRAM::len(&self.params).is_power_of_two()); - - let mut init_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), - num_fixed, - InstancePaddingStrategy::Default, - ); - assert_eq!(init_table.num_padding_instances(), 0); - - init_table - .par_rows_mut() - .zip_eq(io_addrs) - .for_each(|(row, addr)| { - set_fixed_val!(row, self.addr, (*addr as u64).into_f()); - }); - init_table - } - - /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( - &self, - _num_witin: usize, - num_structural_witin: usize, - final_mem: &[MemFinalRecord], - ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - if final_mem.is_empty() { - return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); - } - assert!(num_structural_witin == 0 || num_structural_witin == 1); - let mut value = Vec::with_capacity(NVRAM::len(&self.params)); - value.par_extend( - (0..NVRAM::len(&self.params)) - .into_par_iter() - .map(|_| F::ONE), - ); - let structural_witness = - RowMajorMatrix::::new_by_values(value, 1, InstancePaddingStrategy::Default); - Ok([RowMajorMatrix::empty(), structural_witness]) - } -} - /// volatile with all init value as 0 /// dynamic address as witin, relied on augment of knowledge to prove address form #[derive(Clone, Debug)] diff --git a/examples/examples/sha256.rs b/examples/examples/sha256.rs index d9239b536..05dfb0a27 100644 --- a/examples/examples/sha256.rs +++ b/examples/examples/sha256.rs @@ -10,12 +10,6 @@ fn main() { let h = Sha256::digest(&input); let h_bytes: [u8; 32] = h.into(); - let h: [u32; 8] = core::array::from_fn(|i| { - let chunk = &h_bytes[4 * i..][..4]; - u32::from_be_bytes(chunk.try_into().unwrap()) - }); - // Output the final hash values one by one ceno_rt::commit(&h_bytes); - // debug_print!("{:x}", h[0]); } From 1a0b84f4052784c1f8ef95e7291c79778d41e767 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 27 Mar 2026 17:07:20 +0800 Subject: [PATCH 4/6] refactor(zkvm): unify scalar PI path across proving and recursion - rename ProofInput.pi_evals to pi - remove pi_in_evals and instance_openings plumbing - drop recursion mles/raw_pi_num_variables fields and keep scalar PI conversion path --- ceno_recursion/src/aggregation/internal.rs | 2 +- ceno_recursion/src/aggregation/mod.rs | 6 +- ceno_recursion/src/zkvm_verifier/binding.rs | 60 +++--------------- ceno_recursion/src/zkvm_verifier/verifier.rs | 61 ++++--------------- ceno_zkvm/benches/riscv_add.rs | 3 +- ceno_zkvm/src/chip_handler/general.rs | 4 +- ceno_zkvm/src/e2e.rs | 1 - ceno_zkvm/src/precompiles/bitwise_keccakf.rs | 1 - ceno_zkvm/src/precompiles/fptower/fp.rs | 1 - .../src/precompiles/fptower/fp2_addsub.rs | 1 - ceno_zkvm/src/precompiles/fptower/fp2_mul.rs | 1 - ceno_zkvm/src/precompiles/lookup_keccakf.rs | 1 - ceno_zkvm/src/precompiles/sha256/extend.rs | 1 - ceno_zkvm/src/precompiles/uint256.rs | 1 - .../weierstrass/weierstrass_add.rs | 1 - .../weierstrass/weierstrass_decompress.rs | 1 - .../weierstrass/weierstrass_double.rs | 1 - ceno_zkvm/src/scheme.rs | 28 --------- ceno_zkvm/src/scheme/cpu/mod.rs | 22 ++----- ceno_zkvm/src/scheme/gpu/mod.rs | 22 ++----- ceno_zkvm/src/scheme/hal.rs | 4 +- ceno_zkvm/src/scheme/mock_prover.rs | 36 +++++------ ceno_zkvm/src/scheme/prover.rs | 35 ++--------- ceno_zkvm/src/scheme/tests.rs | 5 +- ceno_zkvm/src/scheme/utils.rs | 11 +--- ceno_zkvm/src/scheme/verifier.rs | 27 +++----- ceno_zkvm/src/structs.rs | 4 +- ceno_zkvm/src/tables/shard_ram.rs | 44 ++----------- gkr_iop/src/chip.rs | 3 +- gkr_iop/src/circuit_builder.rs | 27 ++------ gkr_iop/src/gkr.rs | 2 - gkr_iop/src/gkr/layer.rs | 16 ++--- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 25 +------- gkr_iop/src/gkr/layer_constraint_system.rs | 2 - 34 files changed, 87 insertions(+), 373 deletions(-) diff --git a/ceno_recursion/src/aggregation/internal.rs b/ceno_recursion/src/aggregation/internal.rs index bf0fdfac2..568f1c17b 100644 --- a/ceno_recursion/src/aggregation/internal.rs +++ b/ceno_recursion/src/aggregation/internal.rs @@ -128,7 +128,7 @@ impl NonLeafVerifierVariables { // let expected_last_shard_id = Usize::uninit(builder); // builder.assign(&expected_last_shard_id, pv.len() - Usize::from(1)); - // let shard_id_fs = builder.get(&shard_raw_pi, SHARD_ID_IDX); + // let shard_id_fs = builder.get(&shard_pi, SHARD_ID_IDX); // let shard_id_f = builder.get(&shard_id_fs, 0); // let shard_id = Usize::Var(builder.cast_felt_to_var(shard_id_f)); // builder.assert_usize_eq(expected_last_shard_id, shard_id); diff --git a/ceno_recursion/src/aggregation/mod.rs b/ceno_recursion/src/aggregation/mod.rs index c8d94fedb..8d3706f8d 100644 --- a/ceno_recursion/src/aggregation/mod.rs +++ b/ceno_recursion/src/aggregation/mod.rs @@ -265,7 +265,7 @@ impl CenoAggregationProver { .collect(); let user_public_values: Vec = zkvm_proof_inputs .iter() - .flat_map(|p| p.raw_pi.to_vec()) + .flat_map(|p| p.pi.to_vec()) .collect(); let leaf_inputs = chunk_ceno_leaf_proof_inputs(zkvm_proof_inputs); @@ -398,7 +398,7 @@ impl CenoLeafVmVerifierConfig { builder.cycle_tracker_start("Verify Ceno ZKVM Proof"); let zkvm_proof = ceno_leaf_input.proof; - let raw_pi = zkvm_proof.raw_pi.clone(); + let pi = zkvm_proof.pi.clone(); let _calculated_shard_ec_sum = verify_zkvm_proof(&mut builder, zkvm_proof, &self.vk); builder.cycle_tracker_end("Verify Ceno ZKVM Proof"); @@ -409,7 +409,7 @@ impl CenoLeafVmVerifierConfig { builder.assign(&stark_pvs.app_commit[i], F::ZERO); } - let pv = &raw_pi; + let pv = π let init_pc = builder.get(pv, INIT_PC_IDX); let end_pc = builder.get(pv, END_PC_IDX); let exit_code = builder.get(pv, EXIT_CODE_IDX); diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index aacb10601..08a41e2d9 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -14,7 +14,6 @@ use crate::{ }, }; use ceno_zkvm::{ - instructions::riscv::constants::{LIMB_BITS, LIMB_MASK, UINT_LIMBS}, scheme::{ZKVMChipProof, ZKVMProof}, structs::{EccQuarkProof, TowerProofs, ZKVMVerifyingKey}, }; @@ -42,7 +41,7 @@ pub type E = BinomialExtensionField; pub type RecPcs = Basefold; pub type InnerConfig = AsmConfig; -fn raw_pi_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec { +fn pi_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec { vec![ F::from_canonical_u32(public_values.exit_code & 0xffff), F::from_canonical_u32((public_values.exit_code >> 16) & 0xffff), @@ -66,24 +65,6 @@ fn raw_pi_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> .collect_vec() } -fn mles_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec> { - (0..UINT_LIMBS) - .map(|limb_index| { - public_values - .public_io - .iter() - .map(|value| { - F::from_canonical_u16(((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16) - }) - .collect_vec() - }) - .collect_vec() -} - -fn pi_evals_from_raw_pi(raw_pi: &[F]) -> Vec { - raw_pi.iter().map(|v| E::from(*v)).collect_vec() -} - pub fn decompose_minus_one_bits(n: usize) -> Vec { let a = if n > 0 { n - 1 } else { 0 }; let mut bit_decomp: Vec = vec![]; @@ -114,10 +95,7 @@ pub fn decompose_prefixed_layer_bits(n: usize) -> (Vec, Vec>) { #[derive(DslVariable, Clone)] pub struct ZKVMProofInputVariable { pub shard_id: Usize, - pub raw_pi: Array>, - pub mles: Array>>, - pub raw_pi_num_variables: Array>, - pub pi_evals: Array>, + pub pi: Array>, pub chip_proofs: Array>>, pub max_num_var: Var, pub max_width: Var, @@ -139,11 +117,7 @@ pub struct TowerProofInputVariable { pub(crate) struct ZKVMProofInput { pub shard_id: usize, - pub raw_pi: Vec, - pub mles: Vec>, - pub raw_pi_num_variables: Vec, - // Evaluation of raw_pi. - pub pi_evals: Vec, + pub pi: Vec, pub chip_proofs: BTreeMap, pub witin_commit: BasefoldCommitment, pub opening_proof: BasefoldProof, @@ -155,13 +129,7 @@ impl ZKVMProofInput { zkvm_proof: ZKVMProof, vk: &ZKVMVerifyingKey, ) -> Self { - let raw_pi = raw_pi_from_public_values(&zkvm_proof.public_values); - let mles = mles_from_public_values(&zkvm_proof.public_values); - let raw_pi_num_variables = mles - .iter() - .map(|v| ceil_log2(v.len().next_power_of_two())) - .collect::>(); - let pi_evals = pi_evals_from_raw_pi(&raw_pi); + let pi = pi_from_public_values(&zkvm_proof.public_values); let mut chip_witin_num_vars: HashMap = HashMap::new(); // (chip_id, (num_witin, num_fixed)) let mut chip_indices = zkvm_proof @@ -190,10 +158,7 @@ impl ZKVMProofInput { ZKVMProofInput { shard_id, - raw_pi, - mles, - raw_pi_num_variables, - pi_evals, + pi, chip_proofs: zkvm_proof .chip_proofs .into_iter() @@ -224,10 +189,7 @@ impl Hintable for ZKVMProofInput { fn read(builder: &mut Builder) -> Self::HintVariable { let shard_id = Usize::Var(usize::read(builder)); - let raw_pi = Vec::::read(builder); - let mles = Vec::>::read(builder); - let raw_pi_num_variables = Vec::::read(builder); - let pi_evals = Vec::::read(builder); + let pi = Vec::::read(builder); builder.cycle_tracker_start("read chip proofs"); let chip_proofs = Vec::::read(builder); builder.cycle_tracker_end("read chip proofs"); @@ -243,10 +205,7 @@ impl Hintable for ZKVMProofInput { ZKVMProofInputVariable { shard_id, - raw_pi, - mles, - raw_pi_num_variables, - pi_evals, + pi, chip_proofs, max_num_var, max_width, @@ -316,10 +275,7 @@ impl Hintable for ZKVMProofInput { let fixed_perm = get_perm(fixed_num_vars); stream.extend(>::write(&self.shard_id)); - stream.extend(self.raw_pi.write()); - stream.extend(self.mles.write()); - stream.extend(self.raw_pi_num_variables.write()); - stream.extend(self.pi_evals.write()); + stream.extend(self.pi.write()); stream.extend(vec![vec![F::from_canonical_usize(self.chip_proofs.len())]]); for proofs in self.chip_proofs.values() { stream.extend(proofs.write()); diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 743ce9ad1..365d892c7 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -108,22 +108,13 @@ pub fn verify_zkvm_proof>( let logup_sum: Ext = builder.constant(C::EF::ZERO); for (_, circuit_vk) in vk.circuit_vks.iter() { - for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { - let raw = builder.get(&zkvm_proof_input.raw_pi, instance_value.0); + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance.iter() { + let raw = builder.get(&zkvm_proof_input.pi, instance_value.0); // Match native verifier transcript behavior: append base-field PI element directly. challenger.observe(builder, raw); } } - iter_zip!(builder, zkvm_proof_input.raw_pi, zkvm_proof_input.pi_evals).for_each( - |ptr_vec, builder| { - let raw = builder.iter_ptr_get(&zkvm_proof_input.raw_pi, ptr_vec[0]); - let eval = builder.iter_ptr_get(&zkvm_proof_input.pi_evals, ptr_vec[1]); - let raw_ext = builder.ext_from_base_slice(&[raw]); - builder.assert_ext_eq(raw_ext, eval); - }, - ); - builder .if_eq(zkvm_proof_input.shard_id.clone(), Usize::from(0)) .then(|builder| { @@ -330,9 +321,7 @@ pub fn verify_zkvm_proof>( builder, &mut chip_challenger, &chip_proof, - &zkvm_proof_input.raw_pi, - &zkvm_proof_input.mles, - &zkvm_proof_input.raw_pi_num_variables, + &zkvm_proof_input.pi, &challenges, chip_vk, &unipoly_extrapolator, @@ -367,12 +356,11 @@ pub fn verify_zkvm_proof>( let point_clone: Array> = builder.eval(input_opening_point.clone()); - let (wits_in_evals, fixed_in_evals, _pi_in_evals) = split_input_opening_evals( + let (wits_in_evals, fixed_in_evals) = split_input_opening_evals( builder, &chip_proof, circuit_vk.get_cs().num_witin(), circuit_vk.get_cs().num_fixed(), - circuit_vk.get_cs().instance_openings().len(), ); if circuit_vk.get_cs().num_witin() > 0 { @@ -530,7 +518,7 @@ pub fn verify_zkvm_proof>( .into_iter() .enumerate() .for_each(|(i, idx)| { - let raw = builder.get(&zkvm_proof_input.raw_pi, idx); + let raw = builder.get(&zkvm_proof_input.pi, idx); let eval = builder.ext_from_base_slice(&[raw]); builder.set(&global_state_pi_evals, i, eval); }); @@ -573,12 +561,7 @@ fn split_input_opening_evals( chip_proof: &ZKVMChipProofInputVariable, num_witin: usize, num_fixed: usize, - num_pi: usize, -) -> ( - Array>, - Array>, - Array>, -) { +) -> (Array>, Array>) { let last_layer_idx: Usize = builder.eval(chip_proof.gkr_iop_proof.layer_proofs.len() - Usize::from(1)); let last_layer = builder.get(&chip_proof.gkr_iop_proof.layer_proofs, last_layer_idx); @@ -586,15 +569,13 @@ fn split_input_opening_evals( let wit_end = Usize::from(num_witin); let fixed_end: Usize = builder.eval(wit_end.clone() + Usize::from(num_fixed)); - let pi_end: Usize = builder.eval(fixed_end.clone() + Usize::from(num_pi)); // Native verifier accepts extra trailing evals; only the prefix is consumed here. // Keep recursion semantics aligned by slicing the required prefix. - let eval_prefix = main_evals.slice(builder, Usize::from(0), pi_end.clone()); + let eval_prefix = main_evals.slice(builder, Usize::from(0), fixed_end.clone()); ( eval_prefix.slice(builder, Usize::from(0), wit_end), eval_prefix.slice(builder, Usize::from(num_witin), fixed_end), - eval_prefix.slice(builder, Usize::from(num_witin + num_fixed), pi_end), ) } @@ -603,9 +584,7 @@ pub fn verify_chip_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, chip_proof: &ZKVMChipProofInputVariable, - raw_pi: &Array>, - mles: &Array>>, - raw_pi_num_variables: &Array>, + pi: &Array>, challenges: &Array>, vk: &VerifyingKey, unipoly_extrapolator: &UniPolyExtrapolator, @@ -752,9 +731,9 @@ pub fn verify_chip_proof( }); let gkr_circuit = gkr_circuit.clone().unwrap(); let circuit_pi_evals: Array> = - builder.dyn_array(Usize::from(cs.instance_values.len())); - for (i, instance) in cs.instance_values.iter().enumerate() { - let raw = builder.get(raw_pi, instance.0); + builder.dyn_array(Usize::from(cs.instance.len())); + for (i, instance) in cs.instance.iter().enumerate() { + let raw = builder.get(pi, instance.0); let eval = builder.ext_from_base_slice(&[raw]); builder.set(&circuit_pi_evals, i, eval); } @@ -857,8 +836,6 @@ pub fn verify_chip_proof( &chip_proof.gkr_iop_proof, challenges, &circuit_pi_evals, - mles, - raw_pi_num_variables, &out_evals, selector_ctxs, unipoly_extrapolator, @@ -877,12 +854,10 @@ pub fn verify_gkr_circuit( gkr_proof: &GKRProofVariable, challenges: &Array>, pub_io_evals: &Array>, - mles: &Array>>, - raw_pi_num_variables: &Array>, claims: &Array>, selector_ctxs: Vec>, unipoly_extrapolator: &UniPolyExtrapolator, - poly_evaluator: &mut PolyEvaluator, + _poly_evaluator: &mut PolyEvaluator, ) -> PointVariable { let rt = PointVariable { fs: builder.dyn_array(0), @@ -1171,18 +1146,6 @@ pub fn verify_gkr_circuit( builder.assert_ext_eq(expected_eval, main_wit_eval); } - let pubio_offset = layer.n_witin + layer.n_fixed; - for (index, instance) in layer.instance_openings.iter().enumerate() { - let index: usize = pubio_offset + index; - let poly = builder.get(mles, instance.0); - let num_variable = builder.get(raw_pi_num_variables, instance.0); - let in_point_slice = in_point.slice(builder, 0, num_variable); - let expected_eval = - poly_evaluator.evaluate_base_poly_at_point(builder, &poly, &in_point_slice); - let main_eval = builder.get(&main_evals, index); - builder.assert_ext_eq(expected_eval, main_eval); - } - // TODO: we should store alpha_pows in a bigger array to avoid concatenating them let main_sumcheck_challenges_len: Usize = builder.eval(alpha_pows.len() + Usize::from(2)); diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 3823219d7..ab7e254d3 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -112,8 +112,7 @@ fn bench_add(c: &mut Criterion) { fixed: vec![], witness: polys, structural_witness: vec![], - public_input: vec![], - pub_io_evals: vec![], + pi: vec![], num_instances: vec![num_instances], has_ecc_ops: false, }; diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 3d629b0aa..ef2a5f9bf 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -90,8 +90,8 @@ impl<'a, E: ExtensionField> PublicValuesQuery for CircuitBuilder<'a, E> { fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> { Ok([ - self.cs.query_instance_for_openings(PUBLIC_IO_IDX)?, - self.cs.query_instance_for_openings(PUBLIC_IO_IDX + 1)?, + self.cs.query_instance(PUBLIC_IO_IDX)?, + self.cs.query_instance(PUBLIC_IO_IDX + 1)?, ]) } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index c2b74d730..1fbce7251 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1026,7 +1026,6 @@ pub fn emulate_program<'a>( heap_final.len() as u32, platform.hints.start, hints_final.len() as u32, - vec![], public_io_words_to_digest_words(public_io_digest_input), [0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index 78c8bfe95..76f6eb4a5 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -1003,7 +1003,6 @@ pub fn run_keccakf + 'stat &out_evals, &[], &[], - &[], &mut verifier_transcript, &selector_ctxs, ) diff --git a/ceno_zkvm/src/precompiles/fptower/fp.rs b/ceno_zkvm/src/precompiles/fptower/fp.rs index ba7c04308..a38b81e24 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp.rs @@ -478,7 +478,6 @@ mod tests { gkr_proof, &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs index 76d32b31a..e11a3fc13 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs @@ -523,7 +523,6 @@ mod tests { gkr_proof, &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs index c9160e6d9..352a01a54 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs @@ -536,7 +536,6 @@ mod tests { gkr_proof, &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 52391267a..5dda9f636 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -1263,7 +1263,6 @@ pub fn run_lookup_keccakf gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/sha256/extend.rs b/ceno_zkvm/src/precompiles/sha256/extend.rs index 235e37b95..3f87e2395 100644 --- a/ceno_zkvm/src/precompiles/sha256/extend.rs +++ b/ceno_zkvm/src/precompiles/sha256/extend.rs @@ -554,7 +554,6 @@ mod tests { gkr_proof, &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/uint256.rs b/ceno_zkvm/src/precompiles/uint256.rs index e59e97c55..d8e700faf 100644 --- a/ceno_zkvm/src/precompiles/uint256.rs +++ b/ceno_zkvm/src/precompiles/uint256.rs @@ -944,7 +944,6 @@ pub fn run_uint256_mul + ' gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 012dcab80..40cc0420d 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -778,7 +778,6 @@ pub fn run_weierstrass_add< gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index d6400a2d7..8b3cfa32e 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -758,7 +758,6 @@ pub fn run_weierstrass_decompress< gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 686baa397..f01b2727a 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -780,7 +780,6 @@ pub fn run_weierstrass_double< gkr_proof.clone(), &out_evals, &[], - &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index b16c81675..ae8651f96 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -3,7 +3,6 @@ use ff_ext::ExtensionField; use gkr_iop::gkr::GKRProof; use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; -use multilinear_extensions::mle::{IntoMLE, MultilinearExtension}; use p3::field::FieldAlgebra; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ @@ -87,7 +86,6 @@ pub struct PublicValues { pub heap_shard_len: u32, pub hint_start_addr: u32, pub hint_shard_len: u32, - pub public_io: Vec, pub public_io_digest: [u32; 8], pub shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], } @@ -105,7 +103,6 @@ impl PublicValues { heap_shard_len: u32, hint_start_addr: u32, hint_shard_len: u32, - public_io: Vec, public_io_digest: [u32; 8], shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], ) -> Self { @@ -120,7 +117,6 @@ impl PublicValues { heap_shard_len, hint_start_addr, hint_shard_len, - public_io, public_io_digest, shard_rw_sum, } @@ -158,30 +154,6 @@ impl PublicValues { _ => panic!("public value index {index} out of range"), } } - - pub fn mles(&self) -> Vec> { - // public_io is represented as UINT_LIMBS columns. - (0..UINT_LIMBS) - .map(|limb_index| { - let limb_values = self - .public_io - .iter() - .map(|value| { - E::BaseField::from_canonical_u16( - ((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16, - ) - }) - .collect_vec(); - - // Empty public_io means a constant-zero public input column. - if limb_values.is_empty() { - vec![E::BaseField::ZERO].into_mle() - } else { - limb_values.into_mle() - } - }) - .collect_vec() - } } /// Map circuit names to diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 85decd78e..929c7850d 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -842,8 +842,6 @@ impl> MainSumcheckProver> MainSumcheckProver> MainSumcheckProver { pub witness: Vec>>, pub structural_witness: Vec>>, pub fixed: Vec>>, - pub public_input: Vec>>, - pub pub_io_evals: Vec::BaseField, PB::E>>, + pub pi: Vec::BaseField, PB::E>>, pub num_instances: Vec, pub has_ecc_ops: bool, } @@ -154,7 +153,6 @@ pub trait TowerProver { pub struct MainSumcheckEvals { pub wits_in_evals: Vec, pub fixed_in_evals: Vec, - pub pi_in_evals: Vec, } pub trait MainSumcheckProver { diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 20f4da1b1..7842374fc 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -623,7 +623,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { left, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -638,7 +638,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { &right, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -670,7 +670,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -716,7 +716,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -762,7 +762,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { arg_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, wits_in, structural_witin, @@ -963,19 +963,13 @@ Hints: ) where E: LkMultiplicityKey, { - let all_pi_mles: Vec> = - pi.mles::().into_iter().map(|v| v.into()).collect_vec(); let get_circuit_pi_inputs = |circuit_cs: &ConstraintSystem| { let circuit_pub_io_evals = circuit_cs - .instance_values + .instance .iter() .map(|instance| Either::Right(E::from(pi.query_by_index::(instance.0)))) .collect_vec(); - let circuit_pi_mles = circuit_cs - .instance_openings - .iter() - .map(|instance| all_pi_mles[instance.0].clone()) - .collect_vec(); + let circuit_pi_mles = vec![]; (circuit_pub_io_evals, circuit_pi_mles) }; @@ -1089,7 +1083,7 @@ Hints: &expr.values, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, &fixed, &witness, &structural_witness, @@ -1104,7 +1098,7 @@ Hints: &expr.multiplicity, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, &fixed, &witness, &structural_witness, @@ -1196,7 +1190,7 @@ Hints: ram_type_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, @@ -1209,7 +1203,7 @@ Hints: w_rlc_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, @@ -1296,7 +1290,7 @@ Hints: ram_type_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, @@ -1309,7 +1303,7 @@ Hints: r_rlc_expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, @@ -1337,7 +1331,7 @@ Hints: expr, cs.num_witin, cs.num_fixed as WitnessId, - cs.instance_openings.len(), + 0, fixed, witness, structural_witness, @@ -1474,7 +1468,7 @@ Hints: let gs_init = GlobalState::initial_global_state(&mut cb).unwrap(); let gs_final = GlobalState::finalize_global_state(&mut cb).unwrap(); let gs_pub_io_evals = cs - .instance_values + .instance .iter() .map(|instance| E::from(pi.query_by_index::(instance.0))) .collect_vec(); diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 123af4dda..f7f35add7 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -31,7 +31,6 @@ use crate::structs::ProvingKey; use crate::{ e2e::ShardContext, error::ZKVMError, - instructions::riscv::constants::UINT_LIMBS, scheme::{ hal::{DeviceProvingKey, ProofInput}, utils::build_main_witness, @@ -147,7 +146,6 @@ impl< .get_device_proving_key(shard_ctx) .map(|dpk| dpk.fixed_mles.clone()) .unwrap_or_default(); - let pi_mles_preload = pi.mles::(); info_span!( "[ceno] create_proof_of_shard", @@ -159,7 +157,7 @@ impl< // The order must match verifier and recursion verifier exactly. // TODO deal with vector-based public value to transcript for (_, circuit_pk) in self.pk.circuit_pks.iter() { - for instance_value in circuit_pk.get_cs().zkvm_v1_css.instance_values.iter() { + for instance_value in circuit_pk.get_cs().zkvm_v1_css.instance.iter() { transcript.append_field_element(&pi.query_by_index::(instance_value.0)); } } @@ -272,10 +270,6 @@ impl< ]; tracing::debug!("global challenges in prover: {:?}", challenges); - let public_input_span = entered_span!("public_input", profiling_1 = true); - let public_input = self.device.transport_mles(&pi_mles_preload); - exit_span!(public_input_span); - let main_proofs_span = entered_span!("main_proofs", profiling_1 = true); // Phase 1: Build all ChipTasks @@ -288,7 +282,6 @@ impl< &witness_data, fixed_mles, challenges, - public_input, &pi, &circuit_trace_indices, ); @@ -507,7 +500,6 @@ impl< let MainSumcheckEvals { wits_in_evals, fixed_in_evals, - pi_in_evals, } = evals; exit_span!(span); @@ -525,7 +517,6 @@ impl< MainSumcheckEvals { wits_in_evals, fixed_in_evals, - pi_in_evals, }, input_opening_point, )) @@ -543,7 +534,6 @@ impl< witness_data: &PB::PcsData, mut fixed_mles: Vec>>, challenges: [E; 2], - public_input: Vec>>, pi: &PublicValues, circuit_trace_indices: &[Option], ) -> Vec> { @@ -621,21 +611,10 @@ impl< }; let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); - let public_io = cs - .instance_openings() - .iter() - .map(|Instance(idx)| { - debug_assert!( - *idx < UINT_LIMBS, - "instance_opening index {idx} out of range" - ); - public_input[*idx].clone() - }) - .collect_vec(); - let pi_evals = cs + let circuit_pi = cs .zkvm_v1_css - .instance_values + .instance .iter() .map(|Instance(idx)| Either::Left(pi.query_by_index::(*idx))) .collect_vec(); @@ -644,14 +623,13 @@ impl< witness: witness_mle, fixed, structural_witness, - public_input: public_io, - pub_io_evals: pi_evals, + pi: circuit_pi, num_instances: num_instances.clone(), has_ecc_ops: cs.has_ecc_ops(), }; // SAFETY: All Arcs in ProofInput contain 'static data: // - GPU path: `witness` and `structural_witness` are empty vecs (deferred extraction), - // `fixed` and `public_input` originate from `DeviceProvingKey<'static, PB>`. + // `fixed` originates from `DeviceProvingKey<'static, PB>`. // - CPU path: `witness_mle` may borrow non-'static data, but the CPU path always // uses sequential execution (never enters the concurrent scheduler), so the data // remains valid for the lifetime of `build_chip_tasks`'s caller. @@ -730,7 +708,6 @@ impl< evaluations.push(vec![ result.opening_evals.wits_in_evals, result.opening_evals.fixed_in_evals, - result.opening_evals.pi_in_evals, ]); } chip_proofs @@ -868,7 +845,6 @@ where let MainSumcheckEvals { wits_in_evals, fixed_in_evals, - pi_in_evals, } = evals; exit_span!(span); @@ -886,7 +862,6 @@ where MainSumcheckEvals { wits_in_evals, fixed_in_evals, - pi_in_evals, }, input_opening_point, )) diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 1bed1023a..cde7b08d2 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -211,8 +211,7 @@ fn test_rw_lk_expression_combination() { fixed: vec![], witness: wits_in, structural_witness: structural_in, - public_input: vec![], - pub_io_evals: vec![], + pi: vec![], num_instances: vec![num_instances], has_ecc_ops: false, }; @@ -396,7 +395,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], [0; 8], [0; 14]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, [0; 8], [0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(&shard_ctx, zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index aa2be21fe..8583e2710 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -359,12 +359,6 @@ pub fn build_main_witness< "assert circuit" ); - let pub_io_mles = cs - .instance_openings - .iter() - .map(|instance| input.public_input[instance.0].clone()) - .collect_vec(); - // check all witness size are power of 2 assert!( input @@ -372,7 +366,6 @@ pub fn build_main_witness< .iter() .chain(&input.structural_witness) .chain(&input.fixed) - .chain(&pub_io_mles) .all(|v| { v.evaluations_len() == 1 << num_var_with_rotation }) ); @@ -387,8 +380,8 @@ pub fn build_main_witness< &input.witness, &input.structural_witness, &input.fixed, - &pub_io_mles, - &input.pub_io_evals, + &[], + &input.pi, challenges, ); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 28afc506b..6cfc871fe 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -56,7 +56,7 @@ impl> ZKVMVerifier fn split_input_opening_evals( circuit_vk: &VerifyingKey, proof: &ZKVMChipProof, - ) -> Result<(Vec, Vec, Vec), ZKVMError> { + ) -> Result<(Vec, Vec), ZKVMError> { let cs = circuit_vk.get_cs(); let Some(gkr_proof) = proof.gkr_iop_proof.as_ref() else { return Err(ZKVMError::InvalidProof("missing gkr proof".into())); @@ -68,8 +68,7 @@ impl> ZKVMVerifier let evals = &last_layer.main.evals; let wit_len = cs.num_witin(); let fixed_len = cs.num_fixed(); - let pi_len = cs.instance_openings().len(); - let min_len = wit_len + fixed_len + pi_len; + let min_len = wit_len + fixed_len; if evals.len() < min_len { return Err(ZKVMError::InvalidProof( format!( @@ -83,8 +82,7 @@ impl> ZKVMVerifier let wits_in_evals = evals[..wit_len].to_vec(); let fixed_in_evals = evals[wit_len..(wit_len + fixed_len)].to_vec(); - let pi_in_evals = evals[(wit_len + fixed_len)..(wit_len + fixed_len + pi_len)].to_vec(); - Ok((wits_in_evals, fixed_in_evals, pi_in_evals)) + Ok((wits_in_evals, fixed_in_evals)) } pub fn new(vk: ZKVMVerifyingKey) -> Self { @@ -218,7 +216,7 @@ impl> ZKVMVerifier // Global-state expressions are built from compact instance IDs // (query order), not absolute public-value indices. - let pi_evals = [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] + let pi = [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] .into_iter() .map(|idx| E::from(vm_proof.public_values.query_by_index::(idx))) .collect_vec(); @@ -240,7 +238,7 @@ impl> ZKVMVerifier // Include transcript-visible public values in canonical circuit order. // This must match prover and recursion verifier exactly. for (_, circuit_vk) in self.vk.circuit_vks.iter() { - for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance.iter() { transcript.append_field_element( &vm_proof.public_values.query_by_index::(instance_value.0), ); @@ -461,7 +459,7 @@ impl> ZKVMVerifier &[], &[], &[], - &pi_evals, + &pi, &challenges, &self.vk.initial_global_state_expr, ) @@ -472,7 +470,7 @@ impl> ZKVMVerifier &[], &[], &[], - &pi_evals, + &pi, &challenges, &self.vk.finalize_global_state_expr, ) @@ -695,23 +693,16 @@ impl> ZKVMVerifier ] }; let pi = cs - .instance_values + .instance .iter() .map(|instance| E::from(public_values.query_by_index::(instance.0))) .collect_vec(); - let (wits_in_evals, fixed_in_evals, _pi_in_evals) = - Self::split_input_opening_evals(circuit_vk, proof)?; - let instance_mles = public_values - .mles::() - .into_iter() - .map(|mle| mle.get_base_field_vec().to_vec()) - .collect_vec(); + let (wits_in_evals, fixed_in_evals) = Self::split_input_opening_evals(circuit_vk, proof)?; let (_, rt) = gkr_circuit.verify( num_var_with_rotation, proof.gkr_iop_proof.clone().unwrap(), &evals, &pi, - &instance_mles, challenges, transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f6847140..3490c2f2c 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -155,8 +155,8 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.w_expressions.len() + self.zkvm_v1_css.w_table_expressions.len() } - pub fn instance_openings(&self) -> &[Instance] { - &self.zkvm_v1_css.instance_openings + pub fn instance(&self) -> &[Instance] { + &self.zkvm_v1_css.instance } pub fn has_ecc_ops(&self) -> bool { !self.zkvm_v1_css.ec_final_sum.is_empty() diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index e98a827f2..fb59fdbb7 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -763,21 +763,7 @@ mod tests { shard_rw_sum[i] = fe.as_canonical_u32(); } - let public_value = PublicValues::new( - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - vec![0], // dummy - [0; 8], - shard_rw_sum, - ); + let public_value = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, [0; 8], shard_rw_sum); // assign witness let witness = ShardRamCircuit::assign_instances( @@ -815,38 +801,23 @@ mod tests { let pub_io_evals = pk .get_cs() .zkvm_v1_css - .instance_values + .instance .iter() .map(|instance| Either::Right(E::from(public_value.query_by_index::(instance.0)))) .collect_vec(); - let pi_mles = public_value.mles::(); #[cfg(not(feature = "gpu"))] - let (witness_mles, structural_mles, public_input_mles) = { - let public_input_mles = pk - .get_cs() - .instance_openings() - .iter() - .map(|instance| Arc::new(pi_mles[instance.0].clone())) - .collect_vec(); + let (witness_mles, structural_mles) = { ( witness[0].to_mles().into_iter().map(Arc::new).collect(), witness[1].to_mles().into_iter().map(Arc::new).collect(), - public_input_mles, ) }; #[cfg(feature = "gpu")] - let (witness_mles, structural_mles, public_input_mles) = { + let (witness_mles, structural_mles) = { let cuda_hal = get_cuda_hal().unwrap(); let witness_cpu: Vec<_> = witness[0].to_mles(); let structural_cpu: Vec<_> = witness[1].to_mles(); - let public_cpu: Vec<_> = pk - .get_cs() - .instance_openings() - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .into_iter() - .collect_vec(); ( witness_cpu .iter() @@ -856,10 +827,6 @@ mod tests { .iter() .map(|v| Arc::new(MultilinearExtensionGpu::from_ceno(&cuda_hal, v))) .collect_vec(), - public_cpu - .iter() - .map(|v| Arc::new(MultilinearExtensionGpu::from_ceno(&cuda_hal, v))) - .collect_vec(), ) }; @@ -867,8 +834,7 @@ mod tests { witness: witness_mles, structural_witness: structural_mles, fixed: vec![], - public_input: public_input_mles.clone(), - pub_io_evals, + pi: pub_io_evals, num_instances: vec![n_global_writes as usize, n_global_reads as usize], has_ecc_ops: true, }; diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 2b72f251e..10048418e 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -44,8 +44,7 @@ impl Chip { + cb.cs.r_table_expressions.len() + cb.cs.lk_table_expressions.len() * 2 + cb.cs.num_fixed - + cb.cs.num_witin as usize - + cb.cs.instance_openings.len(), + + cb.cs.num_witin as usize, final_out_evals: (0..cb.cs.w_expressions.len() + cb.cs.r_expressions.len() + cb.cs.lk_expressions.len() diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 8c54fedaf..5dd617da4 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -103,8 +103,7 @@ pub struct ConstraintSystem { pub fixed_namespace_map: Vec, // record which public input index is involving in constraint computation - pub instance_values: Vec, - pub instance_openings: Vec, + pub instance: Vec, pub ec_point_exprs: Vec>, pub ec_slope_exprs: Vec>, @@ -177,8 +176,7 @@ impl ConstraintSystem { num_fixed: 0, fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), - instance_values: vec![], - instance_openings: vec![], + instance: vec![], ec_final_sum: vec![], ec_slope_exprs: vec![], ec_point_exprs: vec![], @@ -265,26 +263,11 @@ impl ConstraintSystem { pub fn query_instance(&mut self, idx: usize) -> Result { let i = Instance(idx); assert!( - !self.instance_values.contains(&i), + !self.instance.contains(&i), "query same pubio idx {idx} value more than once", ); - self.instance_values.push(i); - Ok(Instance(self.instance_values.len() - 1)) - } - - pub fn query_instance_for_openings( - &mut self, - idx: usize, - ) -> Result { - let i = Instance(idx); - - assert!( - !self.instance_openings.contains(&i), - "query same pubio idx {idx} mle more than once", - ); - self.instance_openings.push(i); - - Ok(Instance(self.instance_openings.len() - 1)) + self.instance.push(i); + Ok(Instance(self.instance.len() - 1)) } pub fn rlc_chip_record(&self, items: Vec>) -> Expression { diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index a741775ff..c7dfc8cad 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -123,7 +123,6 @@ impl GKRCircuit { gkr_proof: GKRProof, out_evals: &[PointAndEval], pub_io_evals: &[E], - instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -145,7 +144,6 @@ impl GKRCircuit { layer_proof, &mut evaluations, pub_io_evals, - instance_mles, &mut challenges, transcript, selector_ctxs, diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 711458dbf..0a993f046 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -3,7 +3,7 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use linear_layer::{LayerClaims, LinearLayer}; use multilinear_extensions::{ - Expression, Instance, StructuralWitIn, ToExpr, + Expression, StructuralWitIn, ToExpr, mle::{Point, PointAndEval}, monomial::Term, }; @@ -75,8 +75,6 @@ pub struct Layer { pub max_expr_degree: usize, /// keep all structural witin which could be evaluated succinctly without PCS pub structural_witins: Vec, - /// instance openings - pub instance_openings: Vec, /// num challenges dedicated to this layer. pub n_challenges: usize, /// Expressions to prove in this layer. For zerocheck and linear layers, @@ -158,7 +156,6 @@ impl Layer { ), expr_names: Vec, structural_witins: Vec, - instance_openings: Vec, ) -> Self { assert_eq!(expr_names.len(), exprs.len(), "there are expr without name"); let max_expr_degree = exprs @@ -178,7 +175,6 @@ impl Layer { n_instance, max_expr_degree, structural_witins, - instance_openings, n_challenges, exprs, exprs_with_selector_out_eval_monomial_form: vec![], @@ -258,7 +254,6 @@ impl Layer { proof: LayerProof, claims: &mut [PointAndEval], pub_io_evals: &[E], - instance_mles: &[Vec], challenges: &mut Vec, transcript: &mut Trans, selector_ctxs: &[SelectorContext], @@ -273,7 +268,6 @@ impl Layer { proof, eval_and_dedup_points, pub_io_evals, - instance_mles, challenges, transcript, selector_ctxs, @@ -495,7 +489,7 @@ impl Layer { } = &cb.cs; let in_eval_expr = (non_zero_expr_len..) - .take(cb.cs.num_witin as usize + cb.cs.num_fixed + cb.cs.instance_openings.len()) + .take(cb.cs.num_witin as usize + cb.cs.num_fixed) .collect_vec(); if rotations.is_empty() { Layer::new( @@ -504,7 +498,7 @@ impl Layer { cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - cb.cs.instance_openings.len(), + 0, expressions, n_challenges, in_eval_expr, @@ -512,7 +506,6 @@ impl Layer { ((None, vec![]), 0, 0), expr_names, cb.cs.structural_witins.clone(), - cb.cs.instance_openings.clone(), ) } else { let Some(RotationParams { @@ -529,7 +522,7 @@ impl Layer { cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - cb.cs.instance_openings.len(), + 0, expressions, n_challenges, in_eval_expr, @@ -541,7 +534,6 @@ impl Layer { ), expr_names, cb.cs.structural_witins.clone(), - cb.cs.instance_openings.clone(), ) } } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 7ac6ae8d4..8178a6b25 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -3,7 +3,7 @@ use itertools::{Itertools, chain, izip}; use multilinear_extensions::{ ChallengeId, Expression, StructuralWitIn, StructuralWitInType, ToExpr, WitnessId, macros::{entered_span, exit_span}, - mle::{IntoMLE, Point}, + mle::Point, monomial::Term, monomialize_expr_to_wit_terms, utils::{eval_by_expr, eval_by_expr_with_instance, expr_convert_to_witins}, @@ -77,7 +77,6 @@ pub trait ZerocheckLayer { proof: LayerProof, eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], - instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -228,7 +227,6 @@ impl ZerocheckLayer for Layer { proof: LayerProof, mut eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], - instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -386,27 +384,6 @@ impl ZerocheckLayer for Layer { } } - // check pub-io openings by evaluating the opened public-input MLEs. - let pubio_offset = self.n_witin + self.n_fixed; - for (index, instance) in self.instance_openings.iter().enumerate() { - let index = pubio_offset + index; - let poly = instance_mles - .get(instance.0) - .expect("instance opening index out of bounds for instance_mles") - .clone() - .into_mle(); - let expected_eval = poly.evaluate(&in_point[..poly.num_vars()]); - if expected_eval != main_evals[index] { - return Err(BackendError::LayerVerificationFailed( - format!("layer {} pi mismatch", self.name.clone()).into(), - VerifierError::ClaimNotMatch( - format!("{}", expected_eval).into(), - format!("{}", main_evals[index]).into(), - ), - )); - } - } - let got_claim = eval_by_expr_with_instance( &[], &main_evals, diff --git a/gkr_iop/src/gkr/layer_constraint_system.rs b/gkr_iop/src/gkr/layer_constraint_system.rs index 6bbb1cdc7..eb30aa87b 100644 --- a/gkr_iop/src/gkr/layer_constraint_system.rs +++ b/gkr_iop/src/gkr/layer_constraint_system.rs @@ -425,7 +425,6 @@ impl LayerConstraintSystem { ((None, vec![]), 0, 0), expr_names, vec![], - vec![], ) } else { let Some(RotationParams { @@ -454,7 +453,6 @@ impl LayerConstraintSystem { ), expr_names, vec![], - vec![], ) } } From b3aa1f7c512b19c1d7d6ee286a283965f4e2a9e6 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 27 Mar 2026 17:42:57 +0800 Subject: [PATCH 5/6] new api: commit ctx --- ceno_emul/src/syscalls.rs | 16 ++--- ceno_emul/src/syscalls/pubio_commit.rs | 10 ++- ceno_rt/src/lib.rs | 2 +- ceno_rt/src/mmio.rs | 95 ++++++++++++++++++++++++-- 4 files changed, 106 insertions(+), 17 deletions(-) diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 547412bc7..368ea88cf 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -15,9 +15,11 @@ pub mod uint256; pub use ceno_syscall::{ BLS12381_ADD, BLS12381_DECOMPRESS, BLS12381_DOUBLE, BN254_ADD, BN254_DOUBLE, BN254_FP_ADD, BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, PHANTOM_LOG_PC_CYCLE, - SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, - SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, UINT256_MUL, + PUB_IO_COMMIT, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, + SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, + UINT256_MUL, }; +pub use pubio_commit::PubIoCommitSpec; pub trait SyscallSpec { const NAME: &'static str; @@ -31,14 +33,6 @@ pub trait SyscallSpec { const GKR_OUTPUTS: usize = 0; } -pub struct PubIoCommitSpec; -impl SyscallSpec for PubIoCommitSpec { - const NAME: &'static str = "PUB_IO_COMMIT"; - const REG_OPS_COUNT: usize = 1; - const MEM_OPS_COUNT: usize = 8; - const CODE: u32 = ceno_syscall::PUB_IO_COMMIT; -} - /// Trace the inputs and effects of a syscall. pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result { match function_code { @@ -58,7 +52,7 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< BN254_FP2_ADD => Ok(bn254::bn254_fp2_add(vm)), BN254_FP2_MUL => Ok(bn254::bn254_fp2_mul(vm)), UINT256_MUL => Ok(uint256::uint256_mul(vm)), - code if code == PubIoCommitSpec::CODE => Ok(pubio_commit::pubio_commit(vm)), + PUB_IO_COMMIT => Ok(pubio_commit::pubio_commit(vm)), // phantom syscall PHANTOM_LOG_PC_CYCLE => Ok(phantom::log_pc_cycle(vm)), diff --git a/ceno_emul/src/syscalls/pubio_commit.rs b/ceno_emul/src/syscalls/pubio_commit.rs index aa49281d8..f21a2109d 100644 --- a/ceno_emul/src/syscalls/pubio_commit.rs +++ b/ceno_emul/src/syscalls/pubio_commit.rs @@ -1,9 +1,17 @@ use crate::{Change, EmuContext, Platform, Tracer, VMState, WriteOp, utils::MemoryView}; -use super::{PubIoCommitSpec, SyscallEffects, SyscallSpec, SyscallWitness}; +use super::{SyscallEffects, SyscallSpec, SyscallWitness}; const PUBIO_COMMIT_WORDS: usize = 8; +pub struct PubIoCommitSpec; +impl SyscallSpec for PubIoCommitSpec { + const NAME: &'static str = "PUB_IO_COMMIT"; + const REG_OPS_COUNT: usize = 1; + const MEM_OPS_COUNT: usize = PUBIO_COMMIT_WORDS; + const CODE: u32 = ceno_syscall::PUB_IO_COMMIT; +} + /// Trace the PUB_IO_COMMIT syscall by reading 8 digest words from guest memory. pub fn pubio_commit(vm: &VMState) -> SyscallEffects { let digest_ptr = vm.peek_register(Platform::reg_arg0()); diff --git a/ceno_rt/src/lib.rs b/ceno_rt/src/lib.rs index d3c08b298..dd0d4871f 100644 --- a/ceno_rt/src/lib.rs +++ b/ceno_rt/src/lib.rs @@ -13,7 +13,7 @@ use std::{ mod allocator; mod mmio; -pub use mmio::{commit, read, read_owned, read_slice}; +pub use mmio::{CommitCtx, commit, commit_digest, read, read_owned, read_slice}; mod io; #[cfg(debug_assertions)] diff --git a/ceno_rt/src/mmio.rs b/ceno_rt/src/mmio.rs index cde0f42a7..e3dd90801 100644 --- a/ceno_rt/src/mmio.rs +++ b/ceno_rt/src/mmio.rs @@ -4,6 +4,7 @@ use ceno_serde::from_slice; use ceno_syscall::syscall_pub_io_commit; use core::{cell::UnsafeCell, ptr, slice::from_raw_parts}; use serde::de::DeserializeOwned; +use std::vec::Vec; use tiny_keccak::{Hasher, Keccak}; struct RegionState { @@ -110,13 +111,99 @@ fn digest_to_words(digest: [u8; 32]) -> [u32; 8] { }) } -/// Commit arbitrary public bytes by hashing with Keccak-256 and emitting digest limbs. -pub fn commit(data: &[u8]) { +fn keccak_words(bytes: &[u8]) -> [u32; 8] { let mut keccak = Keccak::v256(); - keccak.update(data); + keccak.update(bytes); let mut digest = [0u8; 32]; keccak.finalize(&mut digest); + digest_to_words(digest) +} - let digest_words = digest_to_words(digest); +/// Commit a precomputed public-io digest. +/// +/// The input must already be an 8-word Keccak-256 digest encoded in little-endian words. +pub fn commit_digest(digest_words: [u32; 8]) { syscall_pub_io_commit(&digest_words); } + +/// Accumulates committed bytes and emits one digest when finalized. +#[derive(Clone, Debug, Default)] +pub struct CommitCtx { + bytes: Vec, +} + +impl CommitCtx { + pub fn new() -> Self { + Self::default() + } + + fn digest_words(&self) -> [u32; 8] { + keccak_words(&self.bytes) + } + + /// Append arbitrary bytes to this context. + pub fn commit(&mut self, data: &[u8]) { + self.bytes.extend_from_slice(data); + } + + /// Compute a final digest by hashing the accumulated bytes once. + pub fn finalized(self) { + commit_digest(self.digest_words()) + } +} + +/// Commit arbitrary public bytes by hashing with Keccak-256 and emitting digest limbs. +pub fn commit(data: &[u8]) { + commit_digest(keccak_words(data)); +} + +#[cfg(test)] +mod tests { + use super::{CommitCtx, digest_to_words, keccak_words}; + use tiny_keccak::{Hasher, Keccak}; + + #[test] + fn keccak_words_matches_manual_conversion() { + let words = keccak_words(b"hello world"); + + let mut manual = Keccak::v256(); + manual.update(b"hello world"); + let mut digest = [0u8; 32]; + manual.finalize(&mut digest); + + assert_eq!(words, digest_to_words(digest)); + } + + #[test] + fn commit_ctx_digest_words_hashes_all_appended_bytes() { + let mut ctx = CommitCtx::new(); + ctx.commit(b"abc"); + ctx.commit(b"123"); + + let got = ctx.digest_words(); + let expected = keccak_words(b"abc123"); + + assert_eq!(got, expected); + } + + #[test] + fn commit_ctx_commit_appends_raw_bytes_before_finalize() { + let mut ctx = CommitCtx::new(); + ctx.commit(b"hello"); + ctx.commit(b" "); + ctx.commit(b"world"); + + let got = ctx.digest_words(); + let expected = keccak_words(b"hello world"); + + assert_eq!(got, expected); + } + + #[test] + #[should_panic(expected = "syscall_pub_io_commit should only run inside zkvm")] + fn commit_ctx_finalized_is_callable() { + let mut ctx = CommitCtx::new(); + ctx.commit(b"payload"); + ctx.finalized(); + } +} From 0f499a17a479b781b42e1600cb92f79443f9ec9e Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 27 Mar 2026 19:39:32 +0800 Subject: [PATCH 6/6] misc: fix recursion verifier --- ceno_recursion/src/zkvm_verifier/binding.rs | 26 +-------------- ceno_zkvm/src/scheme.rs | 32 +++++++++++++++++++ ceno_zkvm/src/scheme/tests.rs | 35 ++++++++++++++++++++- 3 files changed, 67 insertions(+), 26 deletions(-) diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index 08a41e2d9..78f4c3c99 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -41,30 +41,6 @@ pub type E = BinomialExtensionField; pub type RecPcs = Basefold; pub type InnerConfig = AsmConfig; -fn pi_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec { - vec![ - F::from_canonical_u32(public_values.exit_code & 0xffff), - F::from_canonical_u32((public_values.exit_code >> 16) & 0xffff), - F::from_canonical_u32(public_values.init_pc), - F::from_canonical_u64(public_values.init_cycle), - F::from_canonical_u32(public_values.end_pc), - F::from_canonical_u64(public_values.end_cycle), - F::from_canonical_u32(public_values.shard_id), - F::from_canonical_u32(public_values.heap_start_addr), - F::from_canonical_u32(public_values.heap_shard_len), - F::from_canonical_u32(public_values.hint_start_addr), - F::from_canonical_u32(public_values.hint_shard_len), - ] - .into_iter() - .chain( - public_values - .shard_rw_sum - .iter() - .map(|value| F::from_canonical_u32(*value)), - ) - .collect_vec() -} - pub fn decompose_minus_one_bits(n: usize) -> Vec { let a = if n > 0 { n - 1 } else { 0 }; let mut bit_decomp: Vec = vec![]; @@ -129,7 +105,7 @@ impl ZKVMProofInput { zkvm_proof: ZKVMProof, vk: &ZKVMVerifyingKey, ) -> Self { - let pi = pi_from_public_values(&zkvm_proof.public_values); + let pi = zkvm_proof.public_values.iter_field::().collect_vec(); let mut chip_witin_num_vars: HashMap = HashMap::new(); // (chip_id, (num_witin, num_fixed)) let mut chip_indices = zkvm_proof diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index ae8651f96..66e51912b 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -91,6 +91,38 @@ pub struct PublicValues { } impl PublicValues { + pub const fn flattened_len() -> usize { + PUBIO_DIGEST_IDX + PUBIO_DIGEST_U16_LIMBS + } + + pub fn iter_field<'a, Base: FieldAlgebra + 'a>(&'a self) -> impl Iterator + 'a { + [ + Base::from_canonical_u32(self.exit_code & 0xffff), + Base::from_canonical_u32((self.exit_code >> 16) & 0xffff), + Base::from_canonical_u32(self.init_pc), + Base::from_canonical_u64(self.init_cycle), + Base::from_canonical_u32(self.end_pc), + Base::from_canonical_u64(self.end_cycle), + Base::from_canonical_u32(self.shard_id), + Base::from_canonical_u32(self.heap_start_addr), + Base::from_canonical_u32(self.heap_shard_len), + Base::from_canonical_u32(self.hint_start_addr), + Base::from_canonical_u32(self.hint_shard_len), + ] + .into_iter() + .chain( + self.shard_rw_sum + .iter() + .map(|value| Base::from_canonical_u32(*value)), + ) + .chain(self.public_io_digest.iter().flat_map(|word| { + [ + Base::from_canonical_u32(word & 0xffff), + Base::from_canonical_u32((word >> 16) & 0xffff), + ] + })) + } + #[allow(clippy::too_many_arguments)] pub fn new( exit_code: u32, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index cde7b08d2..2f5fe8ca2 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -34,7 +34,7 @@ use ff_ext::{Instrumented, PoseidonField}; use super::{ PublicValues, - constants::MAX_NUM_VARIABLES, + constants::{MAX_NUM_VARIABLES, SEPTIC_EXTENSION_DEGREE}, prover::ZKVMProver, utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, @@ -52,6 +52,39 @@ use p3::field::FieldAlgebra; use rand::thread_rng; use transcript::{BasicTranscript, Transcript}; +#[test] +fn test_public_values_iter_field_matches_query_order() { + type E = GoldilocksExt2; + + let public_values = PublicValues::new( + 0xABCD_1234, + 0x0800_0000, + 123, + 0x0800_1000, + 456, + 7, + 0x3000_0000, + 64, + 0x2800_0000, + 32, + [0, 1, 2, 3, 4, 5, 6, 7], + std::array::from_fn(|i| (i as u32) + 10), + ); + + let from_iter = public_values + .iter_field::<::BaseField>() + .collect_vec(); + let from_query = (0..PublicValues::flattened_len()) + .map(|i| public_values.query_by_index::(i)) + .collect_vec(); + + assert_eq!( + PublicValues::flattened_len(), + 11 + SEPTIC_EXTENSION_DEGREE * 2 + 16 + ); + assert_eq!(from_iter, from_query); +} + struct TestConfig { pub(crate) reg_id: WitIn, }