From ab83b35fc1e6e74bf12c961a8da862663c0b1c09 Mon Sep 17 00:00:00 2001 From: erhant Date: Sun, 28 Jun 2026 21:56:27 +0300 Subject: [PATCH 1/2] feat: CLI `--backend` override and prime-aware witness writing, add `Justfile` --- Justfile | 35 ++++++++++++ crates/circomkit-cli/src/main.rs | 12 +++- crates/circomkit-core/src/enums.rs | 56 +++++++++++++++++++ crates/circomkit-core/src/utils/mod.rs | 2 +- crates/circomkit-core/src/utils/primes.rs | 26 +++++++++ crates/circomkit-core/src/utils/witness.rs | 54 ++++++++++++++---- .../src/lambdaworks/convert.rs | 8 +-- crates/circomkit/src/circomkit/compile.rs | 16 +++--- crates/circomkit/src/circomkit/prove.rs | 16 ++++-- crates/circomkit/src/lib.rs | 2 +- crates/circomkit/tests/e2e/prove.rs | 43 ++++++++++++-- 11 files changed, 233 insertions(+), 37 deletions(-) create mode 100644 Justfile diff --git a/Justfile b/Justfile new file mode 100644 index 0000000..57deed3 --- /dev/null +++ b/Justfile @@ -0,0 +1,35 @@ +## Run `just` to see recipes. + +# List available recipes +default: + @just --list + +# Format all crates +fmt: + cargo fmt --all + +# Check formatting without modifying files (CI-friendly) +fmt-check: + cargo fmt --all --check + +# Lint with clippy across the workspace +lint: + cargo clippy --workspace --all-targets + +# Lint including the native proving backends +lint-all: + cargo clippy --workspace --all-targets --features "prove-arkworks,prove-lambdaworks" + +# Run the test suite (e2e needs `circom` + `snarkjs` on PATH) +test: + cargo test --workspace + +# Run tests including the native proving backends +test-all: + cargo test --workspace --features "prove-arkworks,prove-lambdaworks" + +# Format, lint, and test — the pre-commit gate +check: fmt lint test + +# Same as `check` but with native backends enabled +check-all: fmt lint-all test-all diff --git a/crates/circomkit-cli/src/main.rs b/crates/circomkit-cli/src/main.rs index 8b52a1c..43e3006 100644 --- a/crates/circomkit-cli/src/main.rs +++ b/crates/circomkit-cli/src/main.rs @@ -4,6 +4,7 @@ use anyhow::{Context, Result}; use clap::{Parser, Subcommand}; use circomkit::Circomkit; +use circomkit::ProvingBackendKind; #[derive(Parser)] #[command( @@ -94,6 +95,9 @@ enum Commands { circuit: String, /// Input name input: String, + /// Override the proving backend (snarkjs, arkworks, lambdaworks) + #[arg(long)] + backend: Option, }, /// Verify a proof @@ -243,8 +247,12 @@ fn main() -> Result<()> { println!("witness: {}", path.display()); } - Commands::Prove { circuit, input } => { - let path = ck.prove(&circuit, &input, None)?; + Commands::Prove { + circuit, + input, + backend, + } => { + let path = ck.prove(&circuit, &input, None, backend)?; println!("proof: {}", path.display()); } diff --git a/crates/circomkit-core/src/enums.rs b/crates/circomkit-core/src/enums.rs index b530b00..61fdf91 100644 --- a/crates/circomkit-core/src/enums.rs +++ b/crates/circomkit-core/src/enums.rs @@ -69,6 +69,31 @@ pub enum ProvingBackendKind { Lambdaworks, } +impl fmt::Display for ProvingBackendKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Snarkjs => write!(f, "snarkjs"), + Self::Arkworks => write!(f, "arkworks"), + Self::Lambdaworks => write!(f, "lambdaworks"), + } + } +} + +impl std::str::FromStr for ProvingBackendKind { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + match s.to_ascii_lowercase().as_str() { + "snarkjs" => Ok(Self::Snarkjs), + "arkworks" => Ok(Self::Arkworks), + "lambdaworks" => Ok(Self::Lambdaworks), + other => Err(format!( + "unknown proving backend '{other}' (expected: snarkjs, arkworks, lambdaworks)" + )), + } + } +} + /// Log level for circomkit operations. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "lowercase")] @@ -95,3 +120,34 @@ impl LogLevel { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn proving_backend_kind_from_str_roundtrips() { + for kind in [ + ProvingBackendKind::Snarkjs, + ProvingBackendKind::Arkworks, + ProvingBackendKind::Lambdaworks, + ] { + let parsed: ProvingBackendKind = kind.to_string().parse().unwrap(); + assert_eq!(parsed, kind); + } + // Case-insensitive parsing. + assert_eq!( + "ARKWORKS".parse::().unwrap(), + ProvingBackendKind::Arkworks + ); + } + + #[test] + fn proving_backend_kind_from_str_rejects_unknown() { + let err = "plonky2".parse::().unwrap_err(); + assert!( + err.contains("plonky2"), + "error should name the bad input: {err}" + ); + } +} diff --git a/crates/circomkit-core/src/utils/mod.rs b/crates/circomkit-core/src/utils/mod.rs index c1bcad7..d9b7089 100644 --- a/crates/circomkit-core/src/utils/mod.rs +++ b/crates/circomkit-core/src/utils/mod.rs @@ -3,7 +3,7 @@ mod ptau; mod r1cs; mod witness; -pub use primes::{prime_from_value, prime_value}; +pub use primes::{prime_field_n8, prime_from_value, prime_value}; pub use ptau::{download_ptau, ptau_name_for_constraints, ptau_path_if_exists}; pub use r1cs::{parse_r1cs_bytes, parse_r1cs_info, read_r1cs_file, read_r1cs_info}; pub use witness::{ diff --git a/crates/circomkit-core/src/utils/primes.rs b/crates/circomkit-core/src/utils/primes.rs index fae66aa..9c30e16 100644 --- a/crates/circomkit-core/src/utils/primes.rs +++ b/crates/circomkit-core/src/utils/primes.rs @@ -41,6 +41,19 @@ pub fn prime_value(prime: Prime) -> BigUint { } } +/// Returns the number of bytes used to encode a field element of the given prime +/// in Circom's binary `.wtns`/`.r1cs` formats (the `n8` field). +/// +/// Matches snarkjs/ffjavascript: the modulus bit-length rounded up to a whole +/// number of 64-bit words, times 8. e.g. bn128 (254 bits) → 32, goldilocks +/// (64 bits) → 8. +/// +/// TODO: add link here +pub fn prime_field_n8(prime: Prime) -> u32 { + let bits = prime_value(prime).bits(); + (((bits - 1) / 64 + 1) * 8) as u32 +} + /// Attempts to identify a `Prime` variant from its field value. pub fn prime_from_value(value: &BigUint) -> Option { let primes = [ @@ -81,4 +94,17 @@ mod tests { let unknown = BigUint::from(42u32); assert_eq!(prime_from_value(&unknown), None); } + + #[test] + fn field_n8_matches_snarkjs_word_alignment() { + // 254/255-bit curves pack into 32 bytes (4 64-bit words). + assert_eq!(prime_field_n8(Prime::Bn128), 32); + assert_eq!(prime_field_n8(Prime::Bls12381), 32); + assert_eq!(prime_field_n8(Prime::Grumpkin), 32); + assert_eq!(prime_field_n8(Prime::Pallas), 32); + assert_eq!(prime_field_n8(Prime::Vesta), 32); + assert_eq!(prime_field_n8(Prime::Secq256r1), 32); + // Goldilocks is a 64-bit field — a single word. + assert_eq!(prime_field_n8(Prime::Goldilocks), 8); + } } diff --git a/crates/circomkit-core/src/utils/witness.rs b/crates/circomkit-core/src/utils/witness.rs index 17ff0b3..851f234 100644 --- a/crates/circomkit-core/src/utils/witness.rs +++ b/crates/circomkit-core/src/utils/witness.rs @@ -4,7 +4,7 @@ use std::path::Path; use num_bigint::BigInt; use num_traits::Signed; -use super::primes::prime_value; +use super::primes::{prime_field_n8, prime_value}; use crate::error::{CoreError, Result}; use crate::types::Witness; @@ -98,15 +98,14 @@ pub fn parse_witness_to_elems( )) } -/// Write a witness to a binary `.wtns` file (version 2, BN128 prime). -/// -/// Uses 32 bytes per field element (n8 = 32), which is standard for BN128/BLS12-381. -/// -/// TODO: take prime as argument -pub fn write_witness_file(path: &Path, witness: &Witness) -> Result<()> { - let n8: u32 = 32; - let prime = prime_value(crate::enums::Prime::Bn128); - let prime_bytes = prime.to_bytes_le(); +/// Write a witness to a binary `.wtns` file (version 2) for the given prime. +pub fn write_witness_file( + path: &Path, + witness: &Witness, + prime: crate::enums::Prime, +) -> Result<()> { + let n8 = prime_field_n8(prime); + let prime_bytes = prime_value(prime).to_bytes_le(); let witness_count = witness.len() as u32; @@ -155,3 +154,38 @@ pub fn write_witness_file(path: &Path, witness: &Witness) -> Result<()> { std::fs::write(path, buf)?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::enums::Prime; + + /// Extract the `n8` field from section 1 of a written witness file. + /// Layout: header (12 bytes) + section id (4) + section length (8) → n8 at offset 24. + fn n8_from_bytes(buf: &[u8]) -> u32 { + u32::from_le_bytes(buf[24..28].try_into().unwrap()) + } + + #[test] + fn writes_n8_per_prime_and_roundtrips() { + let dir = std::env::temp_dir().join("circomkit_wtns_n8"); + std::fs::create_dir_all(&dir).unwrap(); + let witness = vec![BigInt::from(1), BigInt::from(42), BigInt::from(255)]; + + // bn128 packs elements into 32 bytes. + let bn = dir.join("bn128.wtns"); + write_witness_file(&bn, &witness, Prime::Bn128).unwrap(); + let bytes = std::fs::read(&bn).unwrap(); + assert_eq!(n8_from_bytes(&bytes), 32); + assert_eq!(parse_witness_bytes(&bytes).unwrap(), witness); + + // goldilocks packs elements into 8 bytes — the old hardcoded 32 would be wrong. + let gl = dir.join("goldilocks.wtns"); + write_witness_file(&gl, &witness, Prime::Goldilocks).unwrap(); + let bytes = std::fs::read(&gl).unwrap(); + assert_eq!(n8_from_bytes(&bytes), 8); + assert_eq!(parse_witness_bytes(&bytes).unwrap(), witness); + + let _ = std::fs::remove_dir_all(&dir); + } +} diff --git a/crates/circomkit-prove/src/lambdaworks/convert.rs b/crates/circomkit-prove/src/lambdaworks/convert.rs index 3787a54..b6bc1c0 100644 --- a/crates/circomkit-prove/src/lambdaworks/convert.rs +++ b/crates/circomkit-prove/src/lambdaworks/convert.rs @@ -24,8 +24,8 @@ fn field_to_dec(elem: &FieldElement) -> String { pub fn proof_to_snarkjs_json(proof: &Proof) -> serde_json::Value { serde_json::json!({ "pi_a": [ - field_to_dec(&proof.pi1.x()), - field_to_dec(&proof.pi1.y()), + field_to_dec(proof.pi1.x()), + field_to_dec(proof.pi1.y()), "1" ], "pi_b": [ @@ -40,8 +40,8 @@ pub fn proof_to_snarkjs_json(proof: &Proof) -> serde_json::Value { ["1", "0"] ], "pi_c": [ - field_to_dec(&proof.pi3.x()), - field_to_dec(&proof.pi3.y()), + field_to_dec(proof.pi3.x()), + field_to_dec(proof.pi3.y()), "1" ], "protocol": "groth16", diff --git a/crates/circomkit/src/circomkit/compile.rs b/crates/circomkit/src/circomkit/compile.rs index 65801c9..c0a6a35 100644 --- a/crates/circomkit/src/circomkit/compile.rs +++ b/crates/circomkit/src/circomkit/compile.rs @@ -32,17 +32,17 @@ impl Circomkit { }; // Check source .circom file mtime - if let Ok(source_mtime) = source_path.metadata().and_then(|m| m.modified()) { - if source_mtime > r1cs_mtime { - return false; - } + if let Ok(source_mtime) = source_path.metadata().and_then(|m| m.modified()) + && source_mtime > r1cs_mtime + { + return false; } // Check generated main file mtime - if let Ok(main_mtime) = main_path.metadata().and_then(|m| m.modified()) { - if main_mtime > r1cs_mtime { - return false; - } + if let Ok(main_mtime) = main_path.metadata().and_then(|m| m.modified()) + && main_mtime > r1cs_mtime + { + return false; } true diff --git a/crates/circomkit/src/circomkit/prove.rs b/crates/circomkit/src/circomkit/prove.rs index e0be8ce..e59d676 100644 --- a/crates/circomkit/src/circomkit/prove.rs +++ b/crates/circomkit/src/circomkit/prove.rs @@ -97,6 +97,8 @@ impl Circomkit { input: &str, data: Option<&CircuitSignals>, ) -> Result { + let prime = self.resolve(circuit)?.compiler.prime; + let input_data = match data { Some(d) => d.clone(), None => self.load_input(circuit, input)?, @@ -114,7 +116,7 @@ impl Circomkit { if let Some(parent) = wtns_path.parent() { std::fs::create_dir_all(parent)?; } - circomkit_core::utils::write_witness_file(&wtns_path, &witness)?; + circomkit_core::utils::write_witness_file(&wtns_path, &witness, prime)?; log::info!("witness computed for {circuit}/{input}"); Ok(wtns_path) @@ -122,19 +124,21 @@ impl Circomkit { /// Generate a proof for a circuit with the given input. /// - /// The proving backend is selected from the resolved circuit config - /// (`prover.backend`), capability-checked against the protocol and curve. - /// snarkjs uses its one-shot `full_prove`; native backends (arkworks, - /// lambdaworks) compute a witness first, then prove from it. + /// The proving backend is taken from `backend` when provided (e.g. the CLI + /// `--backend` flag), otherwise from the resolved circuit config + /// (`prover.backend`). Either way it is capability-checked against the + /// protocol and curve. snarkjs uses its one-shot `full_prove`; native + /// backends (arkworks, lambdaworks) compute a witness first, then prove from it. pub fn prove( &self, circuit: &str, input: &str, data: Option<&CircuitSignals>, + backend: Option, ) -> Result { let resolved = self.resolve(circuit)?; let protocol = resolved.prover.protocol; - let kind = resolved.prover.backend; + let kind = backend.unwrap_or(resolved.prover.backend); let prime = resolved.compiler.prime; let pkey_path = self.paths.pkey(circuit, protocol); diff --git a/crates/circomkit/src/lib.rs b/crates/circomkit/src/lib.rs index 758517b..9016f19 100644 --- a/crates/circomkit/src/lib.rs +++ b/crates/circomkit/src/lib.rs @@ -9,7 +9,7 @@ pub use circomkit::Circomkit; // Re-export key types for convenience pub use circomkit_core::config::{CircomkitConfig, CircuitConfig}; -pub use circomkit_core::enums::{Prime, Protocol}; +pub use circomkit_core::enums::{Prime, Protocol, ProvingBackendKind}; pub use circomkit_core::error::CoreError; pub use circomkit_core::pathing::CircomkitPaths; pub use circomkit_core::types::R1CSInfo; diff --git a/crates/circomkit/tests/e2e/prove.rs b/crates/circomkit/tests/e2e/prove.rs index 7bada7d..d26cd08 100644 --- a/crates/circomkit/tests/e2e/prove.rs +++ b/crates/circomkit/tests/e2e/prove.rs @@ -16,7 +16,7 @@ fn full_prove_and_verify() { let input = signals! { "in" => vec![2_i64, 4, 10] }; let proof_path = ck - .prove("multiplier_3", "prove_test", Some(&input)) + .prove("multiplier_3", "prove_test", Some(&input), None) .unwrap(); assert!(proof_path.exists()); @@ -30,7 +30,7 @@ fn verify_rejects_tampered_signals() { setup_multiplier(&ck); let input = signals! { "in" => vec![2_i64, 4, 10] }; - ck.prove("multiplier_3", "tamper_test", Some(&input)) + ck.prove("multiplier_3", "tamper_test", Some(&input), None) .unwrap(); // Tamper with the public signals file @@ -52,7 +52,7 @@ fn prove_rejects_unsupported_backend_curve() { let input = signals! { "in" => vec![2_i64, 4, 10] }; let err = ck - .prove("multiplier_3", "cap_test", Some(&input)) + .prove("multiplier_3", "cap_test", Some(&input), None) .expect_err("lambdaworks on bn128 must be rejected"); let msg = err.to_string().to_lowercase(); @@ -80,7 +80,9 @@ fn arkworks_backend_through_orchestrator() { ck.config.prover.backend = ProvingBackendKind::Arkworks; let input = signals! { "in" => vec![2_i64, 4, 10] }; - let proof_path = ck.prove("multiplier_3", "ark_e2e", Some(&input)).unwrap(); + let proof_path = ck + .prove("multiplier_3", "ark_e2e", Some(&input), None) + .unwrap(); assert!(proof_path.exists()); // Same proving key snarkjs set up + snarkjs-format proof => snarkjs verifies it. @@ -88,6 +90,37 @@ fn arkworks_backend_through_orchestrator() { assert!(ok, "snarkjs should verify the arkworks-generated proof"); } +/// An explicit backend override (as the CLI `--backend` flag passes) wins over +/// the resolved `prover.backend`. The config defaults to snarkjs (bn128), but +/// overriding to lambdaworks (bls12381-only) surfaces the curve capability error +/// before any proving work — proving the override, not the config, was used. +#[test] +fn prove_backend_override_beats_config() { + use circomkit::core::enums::ProvingBackendKind; + + let (ck, _guard) = test_circomkit(); + + let input = signals! { "in" => vec![2_i64, 4, 10] }; + let err = ck + .prove( + "multiplier_3", + "override_test", + Some(&input), + Some(ProvingBackendKind::Lambdaworks), + ) + .expect_err("override to lambdaworks on bn128 must be rejected"); + + let msg = err.to_string().to_lowercase(); + assert!( + msg.contains("lambdaworks"), + "error should name the overridden backend: {msg}" + ); + assert!( + msg.contains("curve"), + "error should mention the curve: {msg}" + ); +} + #[test] fn witness_then_prove() { let (ck, _guard) = test_circomkit(); @@ -102,7 +135,7 @@ fn witness_then_prove() { // Then prove (will recompute witness via snarkjs, but the file exists) let proof_path = ck - .prove("multiplier_3", "witness_test", Some(&input)) + .prove("multiplier_3", "witness_test", Some(&input), None) .unwrap(); assert!(proof_path.exists()); From 4c3f3bc6da070844cf62cc57b78ab2a50deed67a Mon Sep 17 00:00:00 2001 From: erhant Date: Tue, 30 Jun 2026 22:35:00 +0300 Subject: [PATCH 2/2] fix: copilot comments --- Cargo.lock | 1 + crates/circomkit-core/Cargo.toml | 3 +++ crates/circomkit-core/src/utils/witness.rs | 15 +++++++-------- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3029d5d..4849e88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -661,6 +661,7 @@ dependencies = [ "schemars 0.8.22", "serde", "serde_json", + "tempfile", "thiserror 2.0.18", "ureq", ] diff --git a/crates/circomkit-core/Cargo.toml b/crates/circomkit-core/Cargo.toml index 0383532..a891f86 100644 --- a/crates/circomkit-core/Cargo.toml +++ b/crates/circomkit-core/Cargo.toml @@ -22,3 +22,6 @@ download = ["dep:ureq"] [dependencies.ureq] workspace = true optional = true + +[dev-dependencies] +tempfile.workspace = true diff --git a/crates/circomkit-core/src/utils/witness.rs b/crates/circomkit-core/src/utils/witness.rs index 851f234..0690919 100644 --- a/crates/circomkit-core/src/utils/witness.rs +++ b/crates/circomkit-core/src/utils/witness.rs @@ -168,24 +168,23 @@ mod tests { #[test] fn writes_n8_per_prime_and_roundtrips() { - let dir = std::env::temp_dir().join("circomkit_wtns_n8"); - std::fs::create_dir_all(&dir).unwrap(); + let dir = tempfile::tempdir().unwrap(); + + // create a fake witness let witness = vec![BigInt::from(1), BigInt::from(42), BigInt::from(255)]; - // bn128 packs elements into 32 bytes. - let bn = dir.join("bn128.wtns"); + // bn128 packs elements into 32 bytes + let bn = dir.path().join("bn128.wtns"); write_witness_file(&bn, &witness, Prime::Bn128).unwrap(); let bytes = std::fs::read(&bn).unwrap(); assert_eq!(n8_from_bytes(&bytes), 32); assert_eq!(parse_witness_bytes(&bytes).unwrap(), witness); - // goldilocks packs elements into 8 bytes — the old hardcoded 32 would be wrong. - let gl = dir.join("goldilocks.wtns"); + // goldilocks packs elements into 8 bytes + let gl = dir.path().join("goldilocks.wtns"); write_witness_file(&gl, &witness, Prime::Goldilocks).unwrap(); let bytes = std::fs::read(&gl).unwrap(); assert_eq!(n8_from_bytes(&bytes), 8); assert_eq!(parse_witness_bytes(&bytes).unwrap(), witness); - - let _ = std::fs::remove_dir_all(&dir); } }