Skip to content

Commit

Permalink
feat: add edge case handling for batch_add
Browse files Browse the repository at this point in the history
  • Loading branch information
davidnevadoc committed Jul 3, 2024
1 parent ec576f8 commit 05c4179
Showing 1 changed file with 114 additions and 9 deletions.
123 changes: 114 additions & 9 deletions src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
// Booth encoding:
// * step by `window` size
// * slice by size of `window + 1``
// * each window overlap by 1 bit
// * append a zero bit to the least significant end
// * each window overlap by 1 bit * append a zero bit to the least significant end
// Indexing rule for example window size 3 where we slice by 4 bits:
// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]``
// So we can reduce the bucket size without preprocessing scalars
Expand Down Expand Up @@ -54,7 +53,9 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
}
}

fn batch_add<C: CurveAffine>(
// Batch addition without edge case handling:
// Will panic if a point is the identity or if two points share the x coordinate.
fn batch_add_nonexceptional<C: CurveAffine>(
size: usize,
buckets: &mut [BucketAffine<C>],
points: &[SchedulePoint],
Expand Down Expand Up @@ -85,7 +86,9 @@ fn batch_add<C: CurveAffine>(
acc *= *z;
}

acc = acc.invert().unwrap();
acc = acc
.invert()
.expect("Attempted to invert 0 at batch_add_nmonexceptional");

for (
(
Expand All @@ -112,6 +115,94 @@ fn batch_add<C: CurveAffine>(
}
}

/// Batch addition with edge case handling.
fn batch_add_exceptional<C: CurveAffine>(
size: usize,
buckets: &mut [BucketAffine<C>],
points: &[SchedulePoint],
bases: &[Affine<C>],
) {
let mut t = vec![C::Base::ZERO; size]; // Stores x2 - x1
let mut z = vec![C::Base::ZERO; size]; // Stores y2 - y1
let mut acc = C::Base::ONE;

for (
(
SchedulePoint {
base_idx,
buck_idx,
sign,
},
t,
),
z,
) in points.iter().zip(t.iter_mut()).zip(z.iter_mut())
{
if buckets[*buck_idx].is_inf() {
// We assume bases[*base_idx] != infinity always.
continue;
}

if buckets[*buck_idx].x() == bases[*base_idx].x {
// y-coordinate matches:
// 1. y1 == y2 and sign = false or
// 2. y1 != y2 and sign = true
// => ( y1 == y2) xor !sign
// (This uses the fact that x1 == x2 and both points satisfy the curve eq.)
if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign {
// Doubling
let x_squared = bases[*base_idx].x.square();
*z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); // 2y
*t = acc * (x_squared + x_squared + x_squared); // acc * 3x^2
acc *= *z;
continue;
}
// P + (-P)
buckets[*buck_idx].set_inf();
continue;
}
// Addition
*z = buckets[*buck_idx].x() - bases[*base_idx].x; // x2 - x1
if *sign {
*t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y);
} else {
*t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y);
} // y2 - y1
acc *= *z;
}

acc = acc
.invert()
.expect("Some edge case has not been handled properly");

for (
(
SchedulePoint {
base_idx,
buck_idx,
sign,
},
t,
),
z,
) in points.iter().zip(t.iter()).zip(z.iter()).rev()
{
if buckets[*buck_idx].is_inf() {
// We assume bases[*base_idx] != infinity always.
continue;
}
let lambda = acc * t;
acc *= z; // update acc
let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); // x_result
if *sign {
buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y));
} else {
buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y));
} // y_result = lambda * (x1 - x_result) - y1
buckets[*buck_idx].set_x(&x);
}
}

#[derive(Debug, Clone, Copy)]
struct Affine<C: CurveAffine> {
x: C::Base,
Expand Down Expand Up @@ -207,6 +298,13 @@ impl<C: CurveAffine> BucketAffine<C> {
}
}

fn is_inf(&self) -> bool {
match self {
Self::None => true,
Self::Point(_) => false,
}
}

fn set_x(&mut self, x: &C::Base) {
match self {
Self::None => panic!("::set_x None"),
Expand All @@ -220,6 +318,13 @@ impl<C: CurveAffine> BucketAffine<C> {
Self::Point(ref mut a) => a.y = *y,
}
}

fn set_inf(&mut self) {
match self {
Self::None => {}
Self::Point(_) => *self = Self::None,
}
}
}

struct Schedule<C: CurveAffine> {
Expand Down Expand Up @@ -266,7 +371,7 @@ impl<C: CurveAffine> Schedule<C> {

fn execute(&mut self, bases: &[Affine<C>]) {
if self.ptr != 0 {
batch_add(self.ptr, &mut self.buckets, &self.set, bases);
batch_add_nonexceptional(self.ptr, &mut self.buckets, &self.set, bases);
self.ptr = 0;
self.set
.iter_mut()
Expand Down Expand Up @@ -473,7 +578,6 @@ pub fn best_multiexp_independent_points<C: CurveAffine>(

#[cfg(test)]
mod test {

use std::ops::Neg;

use crate::bn256::{Fr, G1Affine, G1};
Expand Down Expand Up @@ -529,6 +633,7 @@ mod test {
}
}

#[cfg(test)]
fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
let points = (0..1 << max_k)
.map(|_| C::Curve::random(OsRng))
Expand All @@ -545,12 +650,12 @@ mod test {
let points = &points[..1 << k];
let scalars = &scalars[..1 << k];

let t0 = start_timer!(|| format!("cyclone k={}", k));
let e0 = super::best_multiexp_independent_points(scalars, points);
let t0 = start_timer!(|| format!("cyclone indep k={}", k));
let e0 = super::best_multiexp_independent_points(&scalars, &points);
end_timer!(t0);

let t1 = start_timer!(|| format!("older k={}", k));
let e1 = super::best_multiexp(scalars, points);
let e1 = super::best_multiexp(&scalars, &points);
end_timer!(t1);
assert_eq!(e0, e1);
}
Expand Down

0 comments on commit 05c4179

Please sign in to comment.