Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Booth encoding #106

Merged
merged 4 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Cargo.lock
.vscode
**/*.html
.DS_Store
**/*.py

# script generated source code
src/bn256/fr/table.rs
Expand Down
3 changes: 2 additions & 1 deletion src/ff_ext/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ pub trait Legendre {
// The legendre symbol returns 0 for 0
// and 1 for quadratic residues,
// we consider 0 a square hence quadratic residue.
self.legendre().ct_ne(&-1)
unimplemented!()
// self.legendre().ct_ne(&-1)
han0110 marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
293 changes: 267 additions & 26 deletions src/msm.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,55 @@
use std::ops::Neg;

use ff::PrimeField;
use group::Group;
use pasta_curves::arithmetic::CurveAffine;

use crate::multicore;

fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
// Booth encoding:
// * step by `window`e size
// * slice by size of `window + 1``
// * each window overlap by 1 bit
// * append a zero bit to the least significant end
// Indexing rule for example window size 3 where we slice by 4 bits:
// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]``
// So we can reduce the bucket size without preprocessing scalars
// and remembering them as in classic signed digit encoding

let skip_bits = (window_index * window_size).saturating_sub(1);
let skip_bytes = skip_bits / 8;

// fill into a u32
let mut v: [u8; 4] = [0; 4];
for (dst, src) in v.iter_mut().zip(el.iter().skip(skip_bytes)) {
*dst = *src
}
let mut tmp = u32::from_le_bytes(v);

// pad with one 0 if windowing least significant window
if window_index == 0 {
tmp <<= 1;
}

// remove further bits
tmp >>= skip_bits - (skip_bytes * 8);
// apply the booth window
tmp &= (1 << (window_size + 1)) - 1;

let sign = tmp & (1 << window_size) == 0;

// div ceil by 2
tmp = (tmp + 1) >> 1;

// find the booth action index
if sign {
tmp as i32
} else {
((!tmp.saturating_sub(1) & ((1 << window_size) - 1)) as i32).neg()
han0110 marked this conversation as resolved.
Show resolved Hide resolved
}
}

pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

Expand All @@ -15,29 +61,9 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
(f64::from(bases.len() as u32)).ln().ceil() as usize
};

fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
let skip_bits = segment * c;
let skip_bytes = skip_bits / 8;

if skip_bytes >= 32 {
return 0;
}

let mut v = [0; 8];
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
*v = *o;
}
let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1;

let mut tmp = u64::from_le_bytes(v);
tmp >>= skip_bits - (skip_bytes * 8);
tmp %= 1 << c;

tmp as usize
}

let segments = (256 / c) + 1;

for current_segment in (0..segments).rev() {
for current_window in (0..number_of_windows).rev() {
for _ in 0..c {
*acc = acc.double();
}
Expand Down Expand Up @@ -73,12 +99,15 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
}
}

let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];
let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; 1 << (c - 1)];

for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
if coeff != 0 {
buckets[coeff - 1].add_assign(base);
let coeff = get_booth_index(current_window as usize, c, coeff.as_ref());
if coeff.is_positive() {
buckets[coeff as usize - 1].add_assign(base);
}
if coeff.is_negative() {
buckets[coeff.unsigned_abs() as usize - 1].add_assign(&base.neg());
}
}

Expand Down Expand Up @@ -151,3 +180,215 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
acc
}
}

#[cfg(test)]
mod test {

use std::ops::Neg;

use crate::{
bn256::{Fr, G1Affine, G1},
multicore,
};
use ark_std::{end_timer, start_timer};
use ff::{Field, PrimeField};
use group::{Curve, Group};
use pasta_curves::arithmetic::CurveAffine;
use rand_core::OsRng;

// keeping older implementation it here for baseline comparision, debugging & benchmarking
fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

let num_threads = multicore::current_num_threads();
if coeffs.len() > num_threads {
let chunk = coeffs.len() / num_threads;
let num_chunks = coeffs.chunks(chunk).len();
let mut results = vec![C::Curve::identity(); num_chunks];
multicore::scope(|scope| {
let chunk = coeffs.len() / num_threads;

for ((coeffs, bases), acc) in coeffs
.chunks(chunk)
.zip(bases.chunks(chunk))
.zip(results.iter_mut())
{
scope.spawn(move |_| {
multiexp_serial(coeffs, bases, acc);
});
}
});
results.iter().fold(C::Curve::identity(), |a, b| a + b)
} else {
let mut acc = C::Curve::identity();
multiexp_serial(coeffs, bases, &mut acc);
acc
}
}

// keeping older implementation it here for baseline comparision, debugging & benchmarking
fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

let c = if bases.len() < 4 {
1
} else if bases.len() < 32 {
3
} else {
(f64::from(bases.len() as u32)).ln().ceil() as usize
};

fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
let skip_bits = segment * c;
let skip_bytes = skip_bits / 8;

if skip_bytes >= 32 {
return 0;
}

let mut v = [0; 8];
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
*v = *o;
}

let mut tmp = u64::from_le_bytes(v);
tmp >>= skip_bits - (skip_bytes * 8);
tmp %= 1 << c;

tmp as usize
}

let segments = (256 / c) + 1;

for current_segment in (0..segments).rev() {
for _ in 0..c {
*acc = acc.double();
}

#[derive(Clone, Copy)]
enum Bucket<C: CurveAffine> {
None,
Affine(C),
Projective(C::Curve),
}

impl<C: CurveAffine> Bucket<C> {
fn add_assign(&mut self, other: &C) {
*self = match *self {
Bucket::None => Bucket::Affine(*other),
Bucket::Affine(a) => Bucket::Projective(a + *other),
Bucket::Projective(mut a) => {
a += *other;
Bucket::Projective(a)
}
}
}

fn add(self, mut other: C::Curve) -> C::Curve {
match self {
Bucket::None => other,
Bucket::Affine(a) => {
other += a;
other
}
Bucket::Projective(a) => other + a,
}
}
}

let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];

for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
if coeff != 0 {
buckets[coeff - 1].add_assign(base);
}
}

// Summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut running_sum = C::Curve::identity();
for exp in buckets.into_iter().rev() {
running_sum = exp.add(running_sum);
*acc += &running_sum;
}
}
}

#[test]
fn test_booth_encoding() {
fn mul(scalar: &Fr, point: &G1Affine, window: usize) -> G1Affine {
let u = scalar.to_repr();
let n = Fr::NUM_BITS as usize / window + 1;

let table = (0..=1 << (window - 1))
.map(|i| point * Fr::from(i as u64))
.collect::<Vec<_>>();

let mut acc = G1::identity();
for i in (0..n).rev() {
for _ in 0..window {
acc = acc.double();
}

let idx = super::get_booth_index(i as usize, window, u.as_ref());

if idx.is_negative() {
acc += table[idx.unsigned_abs() as usize].neg();
}
if idx.is_positive() {
acc += table[idx.unsigned_abs() as usize];
}
}

acc.to_affine()
}

let (scalars, points): (Vec<_>, Vec<_>) = (0..10)
.map(|_| {
let scalar = Fr::random(OsRng);
let point = G1Affine::random(OsRng);
(scalar, point)
})
.unzip();

for window in 1..10 {
for (scalar, point) in scalars.iter().zip(points.iter()) {
let c0 = mul(scalar, point, window);
let c1 = point * scalar;
assert_eq!(c0, c1.to_affine());
}
}
}

#[test]
fn test_msm_cross() {
let min_k = 10;
let max_k = 22;

let points = (0..1 << max_k)
.map(|_| G1Affine::random(OsRng))
.collect::<Vec<_>>();

let scalars = (0..1 << max_k)
.map(|_| Fr::random(OsRng))
.collect::<Vec<_>>();

for k in min_k..=max_k {
let points = &points[..1 << k];
let scalars = &scalars[..1 << k];

let t0 = start_timer!(|| format!("w/ booth k={}", k));
let e0 = super::best_multiexp(scalars, points);
end_timer!(t0);

let t1 = start_timer!(|| format!("w/o booth k={}", k));
let e1 = best_multiexp(scalars, points);
end_timer!(t1);

assert_eq!(e0, e1);
}
}
}
Loading