diff --git a/halo2_proofs/src/arithmetic.rs b/halo2_proofs/src/arithmetic.rs index 7366fb7b19..59e881f926 100644 --- a/halo2_proofs/src/arithmetic.rs +++ b/halo2_proofs/src/arithmetic.rs @@ -1,6 +1,8 @@ //! This module provides common utilities, traits and structures for group, //! field and polynomial arithmetic. +use std::cmp; + use super::multicore; pub use ff::Field; use group::{ @@ -25,6 +27,7 @@ where { } +// ASSUMES C::Scalar::Repr is little endian fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); @@ -36,6 +39,7 @@ fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut (f64::from(bases.len() as u32)).ln().ceil() as usize }; + // Group `bytes` into bits and take the `segment`th chunk of `c` bits fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> usize { let skip_bits = segment * c; let skip_bytes = skip_bits / 8; @@ -56,9 +60,35 @@ fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut tmp as usize } - let segments = (256 / c) + 1; + // Ideally `segments` should be calculated from the max number of bits among all scalars. But this requires a scan of all scalars, so we don't implement it for now. + let segments = (C::Scalar::NUM_BITS as usize + c - 1) / c; + + // this can be optimized + let mut coeffs_in_segments = Vec::with_capacity(segments); + // track what is the last segment where we actually have nonzero bits, so we completely skip buckets where the scalar bits for all coeffs are 0 + let mut max_nonzero_segment = None; + for current_segment in 0..segments { + let coeff_segments: Vec<_> = coeffs + .iter() + .map(|coeff| { + let c_bits = get_at::(current_segment, c, coeff); + if c_bits != 0 { + max_nonzero_segment = Some(current_segment); + } + c_bits + }) + .collect(); + coeffs_in_segments.push(coeff_segments); + } - for current_segment in (0..segments).rev() { + if max_nonzero_segment.is_none() { + return; + } + for coeffs_seg in coeffs_in_segments + .into_iter() + .take(max_nonzero_segment.unwrap() + 1) + .rev() + { for _ in 0..c { *acc = acc.double(); } @@ -96,9 +126,10 @@ fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; - for (coeff, base) in coeffs.iter().zip(bases.iter()) { - let coeff = get_at::(current_segment, c, coeff); + let mut max_bits = 0; + for (coeff, base) in coeffs_seg.into_iter().zip(bases.iter()) { if coeff != 0 { + max_bits = cmp::max(max_bits, coeff); buckets[coeff - 1].add_assign(base); } } @@ -108,7 +139,7 @@ fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut // (a) + b + // ((a) + b) + c let mut running_sum = C::Curve::identity(); - for exp in buckets.into_iter().rev() { + for exp in buckets.into_iter().take(max_bits).rev() { running_sum = exp.add(running_sum); *acc = *acc + &running_sum; }