diff --git a/Cargo.lock b/Cargo.lock index 3d36b3144..809280d61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1001,6 +1001,9 @@ dependencies = [ "rand_chacha", "rayon", "serde", + "tracing", + "tracing-flame", + "tracing-subscriber", "transcript", ] diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 45d0b163f..905a04f09 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1025,24 +1025,25 @@ impl> ZKVMProver { .collect_vec(); // TODO implement mechanism to skip commitment - let span = entered_span!("pcs_opening"); let (fixed_opening_proof, fixed_commit) = if !fixed.is_empty() { - ( - Some( - PCS::simple_batch_open( - pp, - &fixed, - circuit_pk.fixed_commit_wd.as_ref().unwrap(), - &input_open_point, - fixed_in_evals.as_slice(), - transcript, - ) - .map_err(ZKVMError::PCSError)?, - ), - Some(PCS::get_pure_commitment( - circuit_pk.fixed_commit_wd.as_ref().unwrap(), - )), + let span = entered_span!("pcs_fixed_opening"); + let fixed_opening_proof = PCS::simple_batch_open( + pp, + &fixed, + circuit_pk.fixed_commit_wd.as_ref().unwrap(), + &input_open_point, + fixed_in_evals.as_slice(), + transcript, ) + .map_err(ZKVMError::PCSError)?; + exit_span!(span); + + let span = entered_span!("pcs_fixed_commitment"); + let fixed_commitment = + PCS::get_pure_commitment(circuit_pk.fixed_commit_wd.as_ref().unwrap()); + exit_span!(span); + + (Some(fixed_opening_proof), Some(fixed_commitment)) } else { (None, None) }; @@ -1055,6 +1056,7 @@ impl> ZKVMProver { fixed_in_evals, fixed_commit, ); + let span = entered_span!("pcs_witin_opening"); let wits_opening_proof = PCS::simple_batch_open( pp, &witnesses, @@ -1065,7 +1067,9 @@ impl> ZKVMProver { ) .map_err(ZKVMError::PCSError)?; exit_span!(span); + let span = entered_span!("pcs_witin_commitment"); let wits_commit = PCS::get_pure_commitment(&wits_commit); + exit_span!(span); tracing::debug!( "[table {}] build opening proof for {} polys at {:?}: values = {:?}, commit = {:?}", name, diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index 26ee123ba..8ac4aef06 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -22,6 +22,9 @@ plonky2.workspace = true poseidon.workspace = true rand.workspace = true rand_chacha.workspace = true +tracing.workspace = true +tracing-flame.workspace = true +tracing-subscriber.workspace = true rayon = { workspace = true, optional = true } serde.workspace = true transcript = { path = "../transcript" } diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index a733dd82a..43eae1e34 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -6,12 +6,15 @@ use super::{ sum_check_last_round, }, }; -use crate::util::{ - arithmetic::{interpolate_over_boolean_hypercube, interpolate2_weights}, - field_type_index_ext, field_type_iter_ext, - hash::write_digest_to_transcript, - log2_strict, - merkle_tree::MerkleTree, +use crate::{ + entered_span, exit_span, + util::{ + arithmetic::{interpolate_over_boolean_hypercube, interpolate2_weights}, + field_type_index_ext, field_type_iter_ext, + hash::write_digest_to_transcript, + log2_strict, + merkle_tree::MerkleTree, + }, }; use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; @@ -361,14 +364,18 @@ pub fn simple_batch_commit_phase>( where E::BaseField: Serialize + DeserializeOwned, { - let timer = start_timer!(|| "Simple batch commit phase"); + // let timer = start_timer!(|| "Simple batch commit phase"); + let span = entered_span!("Simple batch commit phase"); assert_eq!(point.len(), num_vars); assert_eq!(comm.num_polys, batch_coeffs.len()); - let prepare_timer = start_timer!(|| "Prepare"); + // let prepare_timer = start_timer!(|| "Prepare"); + let prepare_span = entered_span!("Simple batch commit phase"); let mut trees = Vec::with_capacity(num_vars); - let batch_codewords_timer = start_timer!(|| "Batch codewords"); + // let batch_codewords_timer = start_timer!(|| "Batch codewords"); + let batch_codewords_timer_span = entered_span!("Simple batch commit phase"); let mut running_oracle = comm.batch_codewords(batch_coeffs); - end_timer!(batch_codewords_timer); + // end_timer!(batch_codewords_timer); + exit_span!(batch_codewords_timer_span); let nthreads = max_usable_threads(); let per_thread_size = (1 << num_vars).div_ceil(&nthreads); @@ -383,27 +390,35 @@ where .sum() }) .collect::>(); - end_timer!(prepare_timer); + // end_timer!(prepare_timer); + exit_span!(prepare_span); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube - let build_eq_timer = start_timer!(|| "Basefold::build eq"); + let build_eq_span = entered_span!("mpcs::build_eq_span"); + // let build_eq_timer = start_timer!(|| "Basefold::build eq"); let mut eq = build_eq_x_r_vec(point); - end_timer!(build_eq_timer); + // end_timer!(build_eq_timer); + exit_span!(build_eq_span); - let reverse_bits_timer = start_timer!(|| "Basefold::reverse bits"); + // let reverse_bits_timer = start_timer!(|| "Basefold::reverse bits"); + let reverse_bits_span = entered_span!("mpcs::reverse_bits_span"); reverse_index_bits_in_place(&mut eq); - end_timer!(reverse_bits_timer); + // end_timer!(reverse_bits_timer); + exit_span!(reverse_bits_span); - let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round"); + let sumcheck_span = entered_span!("mpcs::sumcheck_first_round"); + // let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round"); let mut last_sumcheck_message = sum_check_first_round(&mut eq, &mut running_evals); - end_timer!(sumcheck_timer); + // end_timer!(sumcheck_timer); + exit_span!(sumcheck_span); let mut sumcheck_messages = Vec::with_capacity(num_rounds); let mut roots = Vec::with_capacity(num_rounds - 1); let mut final_message = Vec::new(); let mut running_tree_inner = Vec::new(); for i in 0..num_rounds { - let sumcheck_timer = start_timer!(|| format!("Basefold round {}", i)); + // let sumcheck_timer = start_timer!(|| format!("Basefold round {}", i)); + let sumcheck_span = entered_span!("mpcs::sumcheck_round"); // For the first round, no need to send the running root, because this root is // committing to a vector that can be recovered from linearly combining other // already-committed vectors. @@ -414,6 +429,7 @@ where .get_and_append_challenge(b"commit round") .elements; + let inner_span = entered_span!("mpcs::fri_compute_new_running_oracle"); // Fold the current oracle for FRI let new_running_oracle = basefold_one_round_by_interpolation_weights::( pp, @@ -421,21 +437,28 @@ where &running_oracle, challenge, ); + exit_span!(inner_span); if i > 0 { + let inner_span = entered_span!("mpcs::oracle_to_mktree"); let running_tree = MerkleTree::::from_inner_leaves( running_tree_inner, FieldType::Ext(running_oracle), ); + exit_span!(inner_span); trees.push(running_tree); } if i < num_rounds - 1 { + let inner_span = entered_span!("mpcs::oracle_to_mktree"); last_sumcheck_message = sum_check_challenge_round(&mut eq, &mut running_evals, challenge); + exit_span!(inner_span); + let inner_span = entered_span!("mpcs::compute_inner_ext"); running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); let running_root = MerkleTree::::root_from_inner(&running_tree_inner); write_digest_to_transcript(&running_root, transcript); + exit_span!(inner_span); roots.push(running_root); running_oracle = new_running_oracle; } else { @@ -480,9 +503,11 @@ where assert_eq!(basecode, new_running_oracle); } } - end_timer!(sumcheck_timer); + // end_timer!(sumcheck_timer); + exit_span!(sumcheck_span); } - end_timer!(timer); + // end_timer!(timer); + exit_span!(span); (trees, BasefoldCommitPhaseProof { sumcheck_messages, roots, diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index 80ecf2535..9066ce349 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -1,6 +1,7 @@ pub mod arithmetic; pub mod expression; pub mod hash; +pub mod macros; pub mod parallel; pub mod plonky2_util; use ff::{Field, PrimeField}; diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index f2ede8680..8144e3733 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -153,10 +153,12 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { let mut evaluations = AdditiveVec::new(max_degree + 1); // sum for all round poly evaluations vector + let span = entered_span!("main_thread_collect_univariate_result"); for _ in 0..max_thread_id { let round_poly_coeffs = thread_based_transcript.read_field_element_exts(); evaluations += AdditiveVec(round_poly_coeffs); } + exit_span!(span); let span = entered_span!("main_thread_get_challenge"); transcript.append_field_element_exts(&evaluations.0);