Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: fewer recursion shapes #1715

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 64 additions & 47 deletions crates/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub mod verify;

use std::{
borrow::Borrow,
collections::BTreeMap,
collections::{BTreeMap, HashMap},
env,
num::NonZeroUsize,
path::Path,
Expand All @@ -34,6 +34,7 @@ use std::{

use lru::LruCache;

use shapes::SP1ProofShape;
use tracing::instrument;

use p3_baby_bear::BabyBear;
Expand Down Expand Up @@ -102,7 +103,6 @@ const SHRINK_DEGREE: usize = 3;
const WRAP_DEGREE: usize = 9;

const CORE_CACHE_SIZE: usize = 5;
const COMPRESS_CACHE_SIZE: usize = 3;
pub const REDUCE_BATCH_SIZE: usize = 2;

// TODO: FIX
Expand Down Expand Up @@ -135,8 +135,7 @@ pub struct SP1Prover<C: SP1ProverComponents = DefaultProverComponents> {

pub recursion_cache_misses: AtomicUsize,

pub compress_programs:
Mutex<LruCache<SP1CompressWithVkeyShape, Arc<RecursionProgram<BabyBear>>>>,
pub compress_programs: HashMap<SP1CompressWithVkeyShape, Arc<RecursionProgram<BabyBear>>>,

pub compress_cache_misses: AtomicUsize,

Expand Down Expand Up @@ -188,14 +187,6 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
)
.expect("PROVER_CORE_CACHE_SIZE must be a non-zero usize");

let compress_cache_size = NonZeroUsize::new(
env::var("PROVER_COMPRESS_CACHE_SIZE")
.unwrap_or_else(|_| CORE_CACHE_SIZE.to_string())
.parse()
.unwrap_or(COMPRESS_CACHE_SIZE),
)
.expect("PROVER_COMPRESS_CACHE_SIZE must be a non-zero usize");

let core_shape_config = env::var("FIX_CORE_SHAPES")
.map(|v| v.eq_ignore_ascii_case("true"))
.unwrap_or(true)
Expand All @@ -220,14 +211,36 @@ impl<C: SP1ProverComponents> SP1Prover<C> {

let (root, merkle_tree) = MerkleTree::commit(allowed_vk_map.keys().copied().collect());

let mut compress_programs = HashMap::new();
if let Some(config) = &recursion_shape_config {
SP1ProofShape::generate_compress_shapes(config, 2).for_each(|shape| {
let compress_shape = SP1CompressWithVkeyShape {
compress_shape: SP1CompressShape { proof_shapes: shape },
merkle_tree_height: merkle_tree.height,
};
let input = SP1CompressWithVKeyWitnessValues::dummy(
compress_prover.machine(),
&compress_shape,
);
let program = compress_program_from_input::<C>(
recursion_shape_config.as_ref(),
&compress_prover,
vk_verification,
&input,
);
let program = Arc::new(program);
compress_programs.insert(compress_shape, program);
});
}

Self {
core_prover,
compress_prover,
shrink_prover,
wrap_prover,
recursion_programs: Mutex::new(LruCache::new(core_cache_size)),
recursion_cache_misses: AtomicUsize::new(0),
compress_programs: Mutex::new(LruCache::new(compress_cache_size)),
compress_programs,
compress_cache_misses: AtomicUsize::new(0),
vk_root: root,
vk_merkle_tree: merkle_tree,
Expand Down Expand Up @@ -355,40 +368,17 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
&self,
input: &SP1CompressWithVKeyWitnessValues<InnerSC>,
) -> Arc<RecursionProgram<BabyBear>> {
let mut cache = self.compress_programs.lock().unwrap_or_else(|e| e.into_inner());
cache
.get_or_insert(input.shape(), || {
let misses = self.compress_cache_misses.fetch_add(1, Ordering::Relaxed);
tracing::debug!("compress cache miss, misses: {}", misses);
// Get the operations.
let builder_span = tracing::debug_span!("build compress program").entered();
let mut builder = Builder::<InnerConfig>::default();

// read the input.
let input = input.read(&mut builder);
// Verify the proof.
SP1CompressWithVKeyVerifier::verify(
&mut builder,
self.compress_prover.machine(),
input,
self.vk_verification,
PublicValuesOutputDigest::Reduce,
);
let operations = builder.into_operations();
builder_span.exit();

// Compile the program.
let compiler_span = tracing::debug_span!("compile compress program").entered();
let mut compiler = AsmCompiler::<InnerConfig>::default();
let mut program = compiler.compile(operations);
if let Some(recursion_shape_config) = &self.recursion_shape_config {
recursion_shape_config.fix_shape(&mut program);
}
let program = Arc::new(program);
compiler_span.exit();
program
})
.clone()
self.compress_programs.get(&input.shape()).map(Clone::clone).unwrap_or_else(|| {
let misses = self.compress_cache_misses.fetch_add(1, Ordering::Relaxed);
tracing::debug!("compress cache miss, misses: {}", misses);
// Get the operations.
Arc::new(compress_program_from_input::<C>(
self.recursion_shape_config.as_ref(),
&self.compress_prover,
self.vk_verification,
input,
))
})
}

pub fn shrink_program(
Expand Down Expand Up @@ -1217,6 +1207,33 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
}
}

pub fn compress_program_from_input<C: SP1ProverComponents>(
config: Option<&RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>>,
compress_prover: &C::CompressProver,
vk_verification: bool,
input: &SP1CompressWithVKeyWitnessValues<BabyBearPoseidon2>,
) -> RecursionProgram<BabyBear> {
let mut builder = Builder::<InnerConfig>::default();
// read the input.
let input = input.read(&mut builder);
// Verify the proof.
SP1CompressWithVKeyVerifier::verify(
&mut builder,
compress_prover.machine(),
input,
vk_verification,
PublicValuesOutputDigest::Reduce,
);
let operations = builder.into_operations();

// Compile the program.
let mut compiler = AsmCompiler::<InnerConfig>::default();
let mut program = compiler.compile(operations);
if let Some(config) = config {
config.fix_shape(&mut program);
}
program
}
#[cfg(any(test, feature = "export-tests"))]
pub mod tests {

Expand Down
8 changes: 8 additions & 0 deletions crates/prover/src/shapes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ impl SP1ProofShape {
)
}

pub fn generate_compress_shapes(
recursion_shape_config: &RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
reduce_batch_size: usize,
) -> impl Iterator<Item = Vec<ProofShape>> + '_ {
(1..=reduce_batch_size)
.flat_map(|batch_size| recursion_shape_config.get_all_shape_combinations(batch_size))
}

pub fn dummy_vk_map<'a>(
core_shape_config: &'a CoreShapeConfig<BabyBear>,
recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
Expand Down
2 changes: 1 addition & 1 deletion crates/recursion/circuit/src/machine/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub struct SP1CompressWitnessValues<SC: StarkGenericConfig> {

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SP1CompressShape {
proof_shapes: Vec<ProofShape>,
pub proof_shapes: Vec<ProofShape>,
}

impl<C, SC, A> SP1CompressVerifier<C, SC, A>
Expand Down
114 changes: 7 additions & 107 deletions crates/recursion/core/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct RecursionShape {
}

pub struct RecursionShapeConfig<F, A> {
allowed_shapes: Vec<HashMap<String, usize>>,
pub allowed_shapes: Vec<HashMap<String, usize>>,
_marker: PhantomData<(F, A)>,
}

Expand Down Expand Up @@ -101,122 +101,22 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> Default
// Specify allowed shapes.
let allowed_shapes = [
[
(base_alu.clone(), 20),
(mem_var.clone(), 18),
(ext_alu.clone(), 18),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 20),
(mem_var.clone(), 18),
(ext_alu.clone(), 18),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 16),
(poseidon2_wide.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(ext_alu.clone(), 20),
(base_alu.clone(), 19),
(mem_var.clone(), 19),
(poseidon2_wide.clone(), 17),
(mem_const.clone(), 16),
(exp_reverse_bits_len.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 19),
(mem_var.clone(), 18),
(ext_alu.clone(), 18),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 16),
(poseidon2_wide.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 19),
(mem_var.clone(), 18),
(ext_alu.clone(), 18),
(exp_reverse_bits_len.clone(), 16),
(mem_const.clone(), 16),
(poseidon2_wide.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 20),
(ext_alu.clone(), 21),
(base_alu.clone(), 16),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 21),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 18),
(mem_const.clone(), 18),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 21),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 18),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(ext_alu.clone(), 21),
(base_alu.clone(), 20),
(mem_var.clone(), 20),
(poseidon2_wide.clone(), 18),
(mem_const.clone(), 17),
(exp_reverse_bits_len.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 20),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 18),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 20),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 21),
(mem_var.clone(), 20),
(ext_alu.clone(), 20),
(exp_reverse_bits_len.clone(), 18),
(base_alu.clone(), 16),
(mem_var.clone(), 19),
(poseidon2_wide.clone(), 16),
(mem_const.clone(), 18),
(poseidon2_wide.clone(), 18),
(exp_reverse_bits_len.clone(), 18),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
Expand Down