Skip to content

Commit

Permalink
Improve: Generalize Rust benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Nov 12, 2024
1 parent c49abe3 commit 6f69eee
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 75 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ rpath = false # On some systems, setting this to false can help with optimiz
criterion = { version = "0.5.1" }
rand = { version = "0.8.5" }
half = { version = "2.4.1" }
num-traits = "0.2.19"
270 changes: 197 additions & 73 deletions scripts/bench_similarity.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#![allow(unused)]
use rand::Rng;
use std::ops::{AddAssign, Mul};

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use num_traits::{AsPrimitive, Num, NumCast};
use simsimd::SpatialSimilarity as SimSIMD;

const DIMENSIONS: usize = 1536;
Expand All @@ -10,95 +13,199 @@ pub(crate) fn generate_random_vector_f32(dim: usize) -> Vec<f32> {
}

pub(crate) fn generate_random_vector_i8(dim: usize) -> Vec<i8> {
(0..dim).map(|_| rand::thread_rng().gen_range(-128..=127)).collect()
(0..dim)
.map(|_| rand::thread_rng().gen_range(-128..=127))
.collect()
}

/// Builds a vector of `dim` random `u8` values.
pub(crate) fn generate_random_vector_u8(dim: usize) -> Vec<u8> {
    // Grab the thread-local generator once rather than once per element.
    let mut rng = rand::thread_rng();
    (0..dim).map(|_| rng.gen()).collect()
}

// Baseline functions for f32
pub(crate) fn baseline_cos_functional_f32(a: &[f32], b: &[f32]) -> Option<f32> {
pub(crate) fn baseline_cos_unrolled<T, Acc>(a: &[T], b: &[T]) -> Option<f32>
where
T: Num + Copy + NumCast + AsPrimitive<f32>,
Acc: Num + Copy + NumCast + AddAssign + 'static,
T: AsPrimitive<Acc>,
{
if a.len() != b.len() {
return None;
}

let (dot_product, norm_a, norm_b) = a
.iter()
.zip(b)
.map(|(a, b)| (a * b, a * a, b * b))
.fold((0.0, 0.0, 0.0), |acc, x| {
(acc.0 + x.0, acc.1 + x.1, acc.2 + x.2)
});
let mut i = 0;
let remainder = a.len() % 8;
let mut acc1 = Acc::zero();
let mut acc2 = Acc::zero();
let mut acc3 = Acc::zero();
let mut acc4 = Acc::zero();
let mut acc5 = Acc::zero();
let mut acc6 = Acc::zero();
let mut acc7 = Acc::zero();
let mut acc8 = Acc::zero();

let mut norm_a1 = Acc::zero();
let mut norm_a2 = Acc::zero();
let mut norm_b1 = Acc::zero();
let mut norm_b2 = Acc::zero();

while i < (a.len() - remainder) {
unsafe {
let a1 = *a.get_unchecked(i);
let a2 = *a.get_unchecked(i + 1);
let a3 = *a.get_unchecked(i + 2);
let a4 = *a.get_unchecked(i + 3);
let a5 = *a.get_unchecked(i + 4);
let a6 = *a.get_unchecked(i + 5);
let a7 = *a.get_unchecked(i + 6);
let a8 = *a.get_unchecked(i + 7);

let b1 = *b.get_unchecked(i);
let b2 = *b.get_unchecked(i + 1);
let b3 = *b.get_unchecked(i + 2);
let b4 = *b.get_unchecked(i + 3);
let b5 = *b.get_unchecked(i + 4);
let b6 = *b.get_unchecked(i + 5);
let b7 = *b.get_unchecked(i + 6);
let b8 = *b.get_unchecked(i + 7);

let a1_acc: Acc = NumCast::from(a1).unwrap();
let a2_acc: Acc = NumCast::from(a2).unwrap();
let a3_acc: Acc = NumCast::from(a3).unwrap();
let a4_acc: Acc = NumCast::from(a4).unwrap();
let a5_acc: Acc = NumCast::from(a5).unwrap();
let a6_acc: Acc = NumCast::from(a6).unwrap();
let a7_acc: Acc = NumCast::from(a7).unwrap();
let a8_acc: Acc = NumCast::from(a8).unwrap();

let b1_acc: Acc = NumCast::from(b1).unwrap();
let b2_acc: Acc = NumCast::from(b2).unwrap();
let b3_acc: Acc = NumCast::from(b3).unwrap();
let b4_acc: Acc = NumCast::from(b4).unwrap();
let b5_acc: Acc = NumCast::from(b5).unwrap();
let b6_acc: Acc = NumCast::from(b6).unwrap();
let b7_acc: Acc = NumCast::from(b7).unwrap();
let b8_acc: Acc = NumCast::from(b8).unwrap();

acc1 += a1_acc * b1_acc;
acc2 += a2_acc * b2_acc;
acc3 += a3_acc * b3_acc;
acc4 += a4_acc * b4_acc;
acc5 += a5_acc * b5_acc;
acc6 += a6_acc * b6_acc;
acc7 += a7_acc * b7_acc;
acc8 += a8_acc * b8_acc;

norm_a1 += a1_acc * a1_acc + a2_acc * a2_acc + a3_acc * a3_acc + a4_acc * a4_acc;
norm_b1 += b1_acc * b1_acc + b2_acc * b2_acc + b3_acc * b3_acc + b4_acc * b4_acc;

norm_a2 += a5_acc * a5_acc + a6_acc * a6_acc + a7_acc * a7_acc + a8_acc * a8_acc;
norm_b2 += b5_acc * b5_acc + b6_acc * b6_acc + b7_acc * b7_acc + b8_acc * b8_acc;
}

i += 8;
}

// Handle remaining elements
while i < a.len() {
unsafe {
let a_acc: Acc = NumCast::from(*a.get_unchecked(i)).unwrap();
let b_acc: Acc = NumCast::from(*b.get_unchecked(i)).unwrap();
acc1 += a_acc * b_acc;
norm_a1 += a_acc * a_acc;
norm_b1 += b_acc * b_acc;
}
i += 1;
}

Some(1.0 - (dot_product / (norm_a.sqrt() * norm_b.sqrt())))
let dot_product = acc1 + acc2 + acc3 + acc4 + acc5 + acc6 + acc7 + acc8;
let norm_a = norm_a1 + norm_a2;
let norm_b = norm_b1 + norm_b2;

let dot_product_f32: f32 = NumCast::from(dot_product).unwrap();
let norm_a_f32: f32 = NumCast::from(norm_a).unwrap();
let norm_b_f32: f32 = NumCast::from(norm_b).unwrap();

Some(1.0 - (dot_product_f32 / (norm_a_f32.sqrt() * norm_b_f32.sqrt())))
}

pub(crate) fn baseline_cos_unrolled_i8(a: &[i8], b: &[i8]) -> Option<f32> {
pub(crate) fn baseline_l2sq_unrolled<T, Acc>(a: &[T], b: &[T]) -> Option<f32>
where
T: Num + Copy + NumCast,
Acc: Num + Copy + NumCast + AddAssign + 'static,
T: AsPrimitive<Acc>,
{
if a.len() != b.len() {
return None;
}

let mut i = 0;
let remainder = a.len() % 8;
let mut acc1 = 0i32;
let mut acc2 = 0i32;
let mut acc3 = 0i32;
let mut acc4 = 0i32;
let mut acc5 = 0i32;
let mut acc6 = 0i32;
let mut acc7 = 0i32;
let mut acc8 = 0i32;

let mut norm_a1 = 0i32;
let mut norm_a2 = 0i32;
let mut norm_b1 = 0i32;
let mut norm_b2 = 0i32;

let mut acc1 = Acc::zero();
let mut acc2 = Acc::zero();
let mut acc3 = Acc::zero();
let mut acc4 = Acc::zero();
let mut acc5 = Acc::zero();
let mut acc6 = Acc::zero();
let mut acc7 = Acc::zero();
let mut acc8 = Acc::zero();

while i < (a.len() - remainder) {
unsafe {
let a1 = *a.get_unchecked(i) as i32;
let a2 = *a.get_unchecked(i + 1) as i32;
let a3 = *a.get_unchecked(i + 2) as i32;
let a4 = *a.get_unchecked(i + 3) as i32;
let a5 = *a.get_unchecked(i + 4) as i32;
let a6 = *a.get_unchecked(i + 5) as i32;
let a7 = *a.get_unchecked(i + 6) as i32;
let a8 = *a.get_unchecked(i + 7) as i32;

let b1 = *b.get_unchecked(i) as i32;
let b2 = *b.get_unchecked(i + 1) as i32;
let b3 = *b.get_unchecked(i + 2) as i32;
let b4 = *b.get_unchecked(i + 3) as i32;
let b5 = *b.get_unchecked(i + 4) as i32;
let b6 = *b.get_unchecked(i + 5) as i32;
let b7 = *b.get_unchecked(i + 6) as i32;
let b8 = *b.get_unchecked(i + 7) as i32;

acc1 += a1 * b1;
acc2 += a2 * b2;
acc3 += a3 * b3;
acc4 += a4 * b4;
acc5 += a5 * b5;
acc6 += a6 * b6;
acc7 += a7 * b7;
acc8 += a8 * b8;

norm_a1 += a1 * a1 + a2 * a2 + a3 * a3 + a4 * a4;
norm_b1 += b1 * b1 + b2 * b2 + b3 * b3 + b4 * b4;

norm_a2 += a5 * a5 + a6 * a6 + a7 * a7 + a8 * a8;
norm_b2 += b5 * b5 + b6 * b6 + b7 * b7 + b8 * b8;
let a1 = *a.get_unchecked(i);
let a2 = *a.get_unchecked(i + 1);
let a3 = *a.get_unchecked(i + 2);
let a4 = *a.get_unchecked(i + 3);
let a5 = *a.get_unchecked(i + 4);
let a6 = *a.get_unchecked(i + 5);
let a7 = *a.get_unchecked(i + 6);
let a8 = *a.get_unchecked(i + 7);

let b1 = *b.get_unchecked(i);
let b2 = *b.get_unchecked(i + 1);
let b3 = *b.get_unchecked(i + 2);
let b4 = *b.get_unchecked(i + 3);
let b5 = *b.get_unchecked(i + 4);
let b6 = *b.get_unchecked(i + 5);
let b7 = *b.get_unchecked(i + 6);
let b8 = *b.get_unchecked(i + 7);

let diff1 = <Acc as NumCast>::from(a1).unwrap() - <Acc as NumCast>::from(b1).unwrap();
let diff2 = <Acc as NumCast>::from(a2).unwrap() - <Acc as NumCast>::from(b2).unwrap();
let diff3 = <Acc as NumCast>::from(a3).unwrap() - <Acc as NumCast>::from(b3).unwrap();
let diff4 = <Acc as NumCast>::from(a4).unwrap() - <Acc as NumCast>::from(b4).unwrap();
let diff5 = <Acc as NumCast>::from(a5).unwrap() - <Acc as NumCast>::from(b5).unwrap();
let diff6 = <Acc as NumCast>::from(a6).unwrap() - <Acc as NumCast>::from(b6).unwrap();
let diff7 = <Acc as NumCast>::from(a7).unwrap() - <Acc as NumCast>::from(b7).unwrap();
let diff8 = <Acc as NumCast>::from(a8).unwrap() - <Acc as NumCast>::from(b8).unwrap();

acc1 += diff1 * diff1;
acc2 += diff2 * diff2;
acc3 += diff3 * diff3;
acc4 += diff4 * diff4;
acc5 += diff5 * diff5;
acc6 += diff6 * diff6;
acc7 += diff7 * diff7;
acc8 += diff8 * diff8;
}

i += 8;
}

let dot_product = acc1 + acc2 + acc3 + acc4 + acc5 + acc6 + acc7 + acc8;
let norm_a = (norm_a1 + norm_a2) as f32;
let norm_b = (norm_b1 + norm_b2) as f32;
// Handle remaining elements
while i < a.len() {
unsafe {
let a_val = <Acc as NumCast>::from(*a.get_unchecked(i)).unwrap();
let b_val = <Acc as NumCast>::from(*b.get_unchecked(i)).unwrap();
let diff = a_val - b_val;
acc1 += diff * diff;
}
i += 1;
}

let sum = acc1 + acc2 + acc3 + acc4 + acc5 + acc6 + acc7 + acc8;
let sum_f32: f32 = NumCast::from(sum).unwrap();

Some(1.0 - (dot_product as f32 / (norm_a.sqrt() * norm_b.sqrt())))
Some(sum_f32)
}

// Benchmarks
Expand All @@ -119,17 +226,25 @@ pub fn l2sq_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("Squared Euclidean Distance");

for i in 0..=5 {
group.bench_with_input(BenchmarkId::new("SimSIMD_f32", i), &i, |b, _| {
group.bench_with_input(BenchmarkId::new("SimSIMD f32", i), &i, |b, _| {
b.iter(|| SimSIMD::sqeuclidean(&inputs_f32.0, &inputs_f32.1))
});
group.bench_with_input(BenchmarkId::new("SimSIMD_i8", i), &i, |b, _| {
group.bench_with_input(BenchmarkId::new("SimSIMD i8", i), &i, |b, _| {
b.iter(|| SimSIMD::sqeuclidean(&inputs_i8.0, &inputs_i8.1))
});
group.bench_with_input(BenchmarkId::new("SimSIMD_u8", i), &i, |b, _| {
group.bench_with_input(BenchmarkId::new("SimSIMD u8", i), &i, |b, _| {
b.iter(|| SimSIMD::sqeuclidean(&inputs_u8.0, &inputs_u8.1))
});
group.bench_with_input(BenchmarkId::new("Rust Procedural i8", i), &i, |b, _| {
b.iter(|| baseline_cos_unrolled_i8(&inputs_i8.0, &inputs_i8.1))
group.bench_with_input(BenchmarkId::new("Rust Unrolled i8", i), &i, |b, _| {
b.iter(|| baseline_l2sq_unrolled::<i8, i32>(&inputs_i8.0, &inputs_i8.1))
});

group.bench_with_input(BenchmarkId::new("Rust Unrolled u8", i), &i, |b, _| {
b.iter(|| baseline_l2sq_unrolled::<u8, u32>(&inputs_u8.0, &inputs_u8.1))
});

group.bench_with_input(BenchmarkId::new("Rust Unrolled f32", i), &i, |b, _| {
b.iter(|| baseline_l2sq_unrolled::<f32, f32>(&inputs_f32.0, &inputs_f32.1))
});
}
}
Expand All @@ -148,20 +263,29 @@ pub fn cos_benchmark(c: &mut Criterion) {
generate_random_vector_u8(DIMENSIONS),
);

let mut group = c.benchmark_group("SIMD Cosine");
let mut group = c.benchmark_group("Cosine Similarity");

for i in 0..=5 {
group.bench_with_input(BenchmarkId::new("SimSIMD_f32", i), &i, |b, _| {
group.bench_with_input(BenchmarkId::new("SimSIMD f32", i), &i, |b, _| {
b.iter(|| SimSIMD::cosine(&inputs_f32.0, &inputs_f32.1))
});
group.bench_with_input(BenchmarkId::new("SimSIMD_i8", i), &i, |b, _| {
group.bench_with_input(BenchmarkId::new("SimSIMD i8", i), &i, |b, _| {
b.iter(|| SimSIMD::cosine(&inputs_i8.0, &inputs_i8.1))
});
group.bench_with_input(BenchmarkId::new("SimSIMD_u8", i), &i, |b, _| {
group.bench_with_input(BenchmarkId::new("SimSIMD u8", i), &i, |b, _| {
b.iter(|| SimSIMD::cosine(&inputs_u8.0, &inputs_u8.1))
});
group.bench_with_input(BenchmarkId::new("Rust Procedural i8", i), &i, |b, _| {
b.iter(|| baseline_cos_unrolled_i8(&inputs_i8.0, &inputs_i8.1))

group.bench_with_input(BenchmarkId::new("Rust Unrolled i8", i), &i, |b, _| {
b.iter(|| baseline_cos_unrolled::<i8, i32>(&inputs_i8.0, &inputs_i8.1))
});

group.bench_with_input(BenchmarkId::new("Rust Unrolled u8", i), &i, |b, _| {
b.iter(|| baseline_cos_unrolled::<u8, u32>(&inputs_u8.0, &inputs_u8.1))
});

group.bench_with_input(BenchmarkId::new("Rust Unrolled f32", i), &i, |b, _| {
b.iter(|| baseline_cos_unrolled::<f32, f32>(&inputs_f32.0, &inputs_f32.1))
});
}
}
Expand Down

0 comments on commit 6f69eee

Please sign in to comment.