Skip to content

Commit

Permalink
Add a benchmark for two_cap_with_syn to JAX and NumPy benchmarks, and…
Browse files Browse the repository at this point in the history
… the JAX grad benchmark.

This does so by simply adding yet more configurations, which is fine, and some slight translation to test two_cap/two_cap_with_syn. Although we could benchmark one_cap we do not look for that at the moment.

PiperOrigin-RevId: 690872015
  • Loading branch information
Rob Schonberger authored and copybara-github committed Oct 29, 2024
1 parent cb1e9fd commit c80749a
Showing 1 changed file with 47 additions and 16 deletions.
63 changes: 47 additions & 16 deletions python/jax/carfac_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,16 @@ def bench_numpy_in_slices(state: google_benchmark.State):
@google_benchmark.option.measure_process_cpu_time()
@google_benchmark.option.use_real_time()
@google_benchmark.option.unit(google_benchmark.kMicrosecond)
@google_benchmark.option.arg_names(['segment_sample_length'])
@google_benchmark.option.args([220])
@google_benchmark.option.args([2205])
@google_benchmark.option.args([22050])
@google_benchmark.option.args([44100])
@google_benchmark.option.args([220500])
@google_benchmark.option.arg_names([
'segment_sample_length',
'ihc_style',
])
@google_benchmark.option.args_product(
[
[220, 2205, 22050, 44100, 220500],
[0, 1],
],
)
def bench_numpy(state: google_benchmark.State):
"""Benchmark the numpy version of carfac.
Expand All @@ -109,7 +113,12 @@ def bench_numpy(state: google_benchmark.State):
state: the benchmark state for this execution run.
"""
random_seed = 1
ihc_style = 'two_cap'
if state.range(1) == 0:
ihc_style = 'two_cap'
elif state.range(1) == 1:
ihc_style = 'two_cap_with_syn'
else:
raise ValueError('Invalid ihc_style')
cfp = carfac_np.design_carfac(ihc_style=ihc_style)

carfac_np.carfac_init(cfp)
Expand All @@ -134,16 +143,25 @@ def bench_numpy(state: google_benchmark.State):
@google_benchmark.option.measure_process_cpu_time()
@google_benchmark.option.use_real_time()
@google_benchmark.option.unit(google_benchmark.kMicrosecond)
@google_benchmark.option.arg_names(['segment_sample_length'])
@google_benchmark.option.range_multiplier(2)
@google_benchmark.option.range(128, 4096)
@google_benchmark.option.arg_names(['segment_sample_length', 'ihc_style'])
@google_benchmark.option.args_product(
[
[128, 256, 512, 1024, 2048, 4096],
[0, 1],
],
)
def bench_jax_grad(state: google_benchmark.State):
"""Benchmark JAX Value and Gradient function on Carfac.
Args:
state: The Benchmark state for this run.
"""
ihc_style = 'two_cap'
if state.range(1) == 0:
ihc_style = 'two_cap'
elif state.range(1) == 1:
ihc_style = 'two_cap_with_syn'
else:
raise ValueError('Invalid ihc_style')
random_seed = 1
params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].ihc.ihc_style = ihc_style
Expand Down Expand Up @@ -313,11 +331,19 @@ def bench_jax_in_slices(state: google_benchmark.State):
@google_benchmark.option.measure_process_cpu_time()
@google_benchmark.option.use_real_time()
@google_benchmark.option.unit(google_benchmark.kMicrosecond)
@google_benchmark.option.arg_names(
['jax_chunked_uncompiled', 'segment_sample_length', 'use_delay_buffer']
)
@google_benchmark.option.arg_names([
'jax_chunked_uncompiled',
'segment_sample_length',
'use_delay_buffer',
'ihc_style',
])
@google_benchmark.option.args_product(
[[0, 1, 2], [220, 2205, 22050, 44100, 220500, 2205000], [False, True]],
[
[0, 1, 2],
[220, 2205, 22050, 44100, 220500, 2205000],
[False, True],
[0, 1],
],
)
def bench_jax(state: google_benchmark.State):
"""Benchmark the JAX version of carfac.
Expand All @@ -335,7 +361,12 @@ def bench_jax(state: google_benchmark.State):
state: the benchmark state for this execution run.
"""
# Inits JAX version
ihc_style = 'two_cap'
if state.range(3) == 0:
ihc_style = 'two_cap'
elif state.range(3) == 1:
ihc_style = 'two_cap_with_syn'
else:
raise ValueError('Invalid ihc_style')
random_seed = 1
params_jax = carfac_jax.CarfacDesignParameters()
params_jax.ears[0].car.use_delay_buffer = state.range(2)
Expand Down

0 comments on commit c80749a

Please sign in to comment.